import os.path as op
import os
import logging
from time import time
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import zscore
from scipy.ndimage import binary_dilation
import dipy.tracking.streamline as dts
import dipy.tracking.streamlinespeed as dps
from dipy.segment.bundles import RecoBundles
from dipy.io.stateful_tractogram import StatefulTractogram, Space
from dipy.io.streamline import save_tractogram, load_tractogram
from dipy.utils.parallel import paramap
from dipy.segment.clustering import QuickBundles
from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric
from dipy.segment.featurespeed import ResampleFeature
import AFQ.utils.models as ut
import AFQ.data.fetch as afd
from AFQ.data.utils import BUNDLE_RECO_2_AFQ
from AFQ.api.bundle_dict import BundleDict
from AFQ.definitions.mapping import ConformedFnirtMapping
from AFQ._fixes import gaussian_weights
__all__ = ["Segmentation", "clean_bundle", "clean_by_endpoints"]
logger = logging.getLogger('AFQ')
def _resample_tg(tg, n_points):
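    """
    Resample a collection of streamlines to n_points points each.
    Accepts a StatefulTractogram, an array of streamlines, or a list
    of per-streamline arrays.
    """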
# reformat for dipy's set_number_of_points
if isinstance(tg, np.ndarray):
if len(tg.shape) > 2:
streamlines = tg.tolist()
streamlines = [np.asarray(item) for item in streamlines]
elif hasattr(tg, "streamlines"):
streamlines = tg.streamlines
else:
streamlines = tg
return dps.set_number_of_points(streamlines, n_points)
class _SlsBeingRecognized:
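    """
    Bookkeeping for the streamlines still under consideration for a
    given bundle during recognition: tracks the selected streamline
    indices, per-streamline flips, bundle votes, and (when inclusion
    ROIs are used) the per-ROI closest-node information.
    """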
def __init__(self, sls, logger, save_intermediates, b_name, ref,
n_roi_dists):
self.oriented_yet = False
self.selected_fiber_idxs = np.arange(len(sls), dtype=np.uint32)
        self.sls_flipped = np.zeros(len(sls), dtype=np.bool_)
self.bundle_vote = np.full(len(sls), -np.inf, dtype=np.float32)
self.logger = logger
self.start_time = -1
self.save_intermediates = save_intermediates
self.b_name = b_name
self.ref_sls = sls
self.ref = ref
self.n_roi_dists = n_roi_dists
def initiate_selection(self, clean_name):
self.start_time = time()
self.logger.info(f"Filtering by {clean_name}")
        return np.zeros(len(self.selected_fiber_idxs), dtype=np.bool_)
def select(self, idx, clean_name, cut=False):
self.selected_fiber_idxs = self.selected_fiber_idxs[idx]
self.sls_flipped = self.sls_flipped[idx]
self.bundle_vote = self.bundle_vote[idx]
if hasattr(self, "roi_dists"):
self.roi_dists = self.roi_dists[idx]
time_taken = time() - self.start_time
self.logger.info(
f"After filtering by {clean_name} (time: {time_taken}s), "
f"{len(self)} streamlines remain.")
if self.save_intermediates is not None:
save_tractogram(
StatefulTractogram(
self.get_selected_sls(cut=cut),
self.ref, Space.VOX),
op.join(self.save_intermediates,
f'sls_after_{clean_name}_for_{self.b_name}.trk'),
bbox_valid_check=False)
def get_selected_sls(self, cut=False, flip=False):
selected_sls = self.ref_sls[self.selected_fiber_idxs]
if cut and hasattr(self, "roi_dists") and self.n_roi_dists > 1:
selected_sls = _cut_sls_by_dist(
selected_sls, self.roi_dists,
(0, self.n_roi_dists - 1),
in_place=False)
if flip:
selected_sls = _flip_sls(
selected_sls, self.sls_flipped,
in_place=False)
return selected_sls
def reorient(self, idx):
if self.oriented_yet:
raise RuntimeError((
"Attempted to oriented streamlines "
"that were already oriented. "
"This is a bug in the implementation of a "
"bundle recognition procedure. "))
self.oriented_yet = True
self.sls_flipped[idx] = True
def __bool__(self):
return len(self) > 0
def __len__(self):
return len(self.selected_fiber_idxs)
class Segmentation:
def __init__(self,
nb_points=False,
nb_streamlines=False,
seg_algo='AFQ',
clip_edges=False,
parallel_segmentation={"engine": "serial"},
progressive=True,
greater_than=50,
rm_small_clusters=50,
model_clust_thr=1.25,
reduction_thr=25,
refine=False,
pruning_thr=12,
b0_threshold=50,
prob_threshold=0,
roi_dist_tie_break=False,
dist_to_waypoint=None,
rng=None,
return_idx=False,
presegment_bundle_dict=None,
presegment_kwargs={},
filter_by_endpoints=True,
dist_to_atlas=4,
save_intermediates=None,
cleaning_params={}):
"""
Segment streamlines into bundles.
Parameters
----------
nb_points : int, boolean
Resample streamlines to nb_points number of points.
If False, no resampling is done. Default: False
nb_streamlines : int, boolean
Subsample streamlines to nb_streamlines.
            If False, no subsampling is done. Default: False
seg_algo : string
Algorithm for segmentation (case-insensitive):
'AFQ': Segment streamlines into bundles,
based on inclusion/exclusion ROIs.
'Reco': Segment streamlines using the RecoBundles algorithm
[Garyfallidis2017].
Default: 'AFQ'
clip_edges : bool
Whether to clip the streamlines to be only in between the ROIs.
Default: False
        parallel_segmentation : dict
How to parallelize segmentation across processes when performing
waypoint ROI segmentation. Set to {"engine": "serial"} to not
perform parallelization. Some engines may cause errors, depending
on the system. See ``dipy.utils.parallel.paramap`` for
details.
Default: {"engine": "serial"}
rm_small_clusters : int
Using RecoBundles Algorithm.
Remove clusters that have less than this value
during whole brain SLR.
Default: 50
model_clust_thr : int
Parameter passed on to recognize for Recobundles.
See Recobundles documentation.
Default: 1.25
reduction_thr : int
Parameter passed on to recognize for Recobundles.
See Recobundles documentation.
Default: 25
refine : bool
Parameter passed on to recognize for Recobundles.
See Recobundles documentation.
Default: False
pruning_thr : int
Parameter passed on to recognize for Recobundles.
See Recobundles documentation.
Default: 12
progressive : boolean, optional
Using RecoBundles Algorithm.
Whether or not to use progressive technique
during whole brain SLR.
Default: True.
greater_than : int, optional
Using RecoBundles Algorithm.
Keep streamlines that have length greater than this value
during whole brain SLR.
Default: 50.
b0_threshold : float.
Using AFQ Algorithm.
            All b-values with values less than or equal to `b0_threshold` are
considered as b0s i.e. without diffusion weighting.
Default: 50.
prob_threshold : float.
Using AFQ Algorithm.
Initial cleaning of fiber groups is done using probability maps
from [Hua2008]_. Here, we choose an average probability that
needs to be exceeded for an individual streamline to be retained.
Default: 0.
roi_dist_tie_break : bool.
Whether to use distance from nearest ROI as a tie breaker when a
streamline qualifies as a part of multiple bundles. If False,
probability maps are used.
Default : False.
dist_to_waypoint : float.
The distance that a streamline node has to be from the waypoint
ROI in order to be included or excluded.
If set to None (default), will be calculated as the
center-to-corner distance of the voxel in the diffusion data.
If a bundle has inc_addtol or exc_addtol in its bundle_dict, that
tolerance will be added to this distance.
For example, if you wanted to increase tolerance for the right
arcuate waypoint ROIs by 3 each, you could make the following
modification to your bundle_dict:
bundle_dict["Right Arcuate"]["inc_addtol"] = [3, 3]
Additional tolerances can also be negative.
rng : RandomState or int
If None, creates RandomState.
If int, creates RandomState with seed rng.
Used in RecoBundles Algorithm.
Default: None.
return_idx : bool
Whether to return the indices in the original streamlines as part
of the output of segmentation.
presegment_bundle_dict : dict or None
If not None, presegment by ROIs before performing
RecoBundles. Only used if seg_algo starts with 'Reco'.
Meta-data for the segmentation. The format is something like::
{'bundle_name': {
'include':[img1, img2],
'prob_map': img3,
'cross_midline': False,
'start': img4,
'end': img5}}
Default: None
presegment_kwargs : dict
Optional arguments for initializing the segmentation for the
presegmentation. Only used if presegment_bundle_dict is not None.
Default: {}
filter_by_endpoints: bool
Whether to filter the bundles based on their endpoints.
Applies only when `seg_algo == 'AFQ'`.
Default: True.
dist_to_atlas : float
If filter_by_endpoints is True, this is the required distance
from the endpoints to the atlas ROIs.
save_intermediates : str, optional
The full path to a folder into which intermediate products
are saved. Default: None, means no saving of intermediates.
cleaning_params : dict, optional
Cleaning params to pass to seg.clean_bundle. This will
override the default parameters of that method. However, this
            can be overridden by setting the cleaning parameters in the
bundle_dict. Default: {}.
References
----------
.. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008)
Tract probability maps in stereotaxic spaces: analyses of white
matter anatomy and tract-specific quantification. Neuroimage 39:
336-347
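
        Examples
        --------
        A minimal sketch of constructing a segmentation object; the
        keyword values here are illustrative, not recommendations:

        >>> seg = Segmentation(seg_algo='AFQ',
        ...                    clip_edges=True,
        ...                    prob_threshold=0.1)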
"""
self.logger = logger
self.nb_points = nb_points
self.nb_streamlines = nb_streamlines
if rng is None:
self.rng = np.random.RandomState()
elif isinstance(rng, int):
self.rng = np.random.RandomState(rng)
else:
self.rng = rng
self.seg_algo = seg_algo.lower()
self.prob_threshold = prob_threshold
self.roi_dist_tie_break = roi_dist_tie_break
self.dist_to_waypoint = dist_to_waypoint
self.b0_threshold = b0_threshold
self.progressive = progressive
self.greater_than = greater_than
self.rm_small_clusters = rm_small_clusters
self.model_clust_thr = model_clust_thr
self.reduction_thr = reduction_thr
self.refine = refine
self.pruning_thr = pruning_thr
self.return_idx = return_idx
self.presegment_bundle_dict = presegment_bundle_dict
self.presegment_kwargs = presegment_kwargs
self.filter_by_endpoints = filter_by_endpoints
self.dist_to_atlas = dist_to_atlas
self.parallel_segmentation = parallel_segmentation
self.cleaning_params = cleaning_params
if (save_intermediates is not None) and \
(not op.exists(save_intermediates)):
os.makedirs(save_intermediates, exist_ok=True)
self.save_intermediates = save_intermediates
self.clip_edges = clip_edges
    def _read_tg(self, tg=None):
if tg is None:
tg = self.tg
else:
self.tg = tg
self._tg_orig_space = self.tg.space
if self.nb_streamlines and len(self.tg) > self.nb_streamlines:
self.tg = StatefulTractogram.from_sft(
dts.select_random_set_of_streamlines(
self.tg.streamlines,
self.nb_streamlines
),
self.tg
)
return tg
    def segment(self, bundle_dict, tg, mapping, img,
reg_prealign=None,
reg_template=None, reset_tg_space=False):
"""
Segment streamlines into bundles based on either waypoint ROIs
[Yeatman2012]_ or RecoBundles [Garyfallidis2017]_.
Parameters
----------
bundle_dict: dict or AFQ.api.BundleDict
Meta-data for the segmentation. The format is something like::
{'bundle_name': {
'include':[img1, img2],
'prob_map': img3,
'cross_midline': False,
'start': img4,
'end': img5}}
tg : StatefulTractogram
Bundles to segment
mapping : DiffeomorphicMap, or equivalent interface
A mapping between DWI space and a template.
img : Nifti1Image
Image to use as reference.
reg_prealign : array, optional.
The linear transformation to be applied to align input images to
the reference space before warping under the deformation field.
Default: None.
reg_template : str or nib.Nifti1Image, optional.
Template to use for registration. Default: MNI T2.
reset_tg_space : bool, optional
Whether to reset the space of the input tractogram after
segmentation is complete. Default: False.
        Returns
        -------
        fiber_groups : dict
            Keys are bundle names, values are tractograms of these
            bundles.
        meta : dict
            Segmentation metadata for each bundle.
References
----------
.. [Yeatman2012] Yeatman, Jason D., Robert F. Dougherty, Nathaniel J.
Myall, Brian A. Wandell, and Heidi M. Feldman. 2012. "Tract Profiles of
White Matter Properties: Automating Fiber-Tract Quantification"
PloS One 7 (11): e49790.
        .. [Garyfallidis2017] Garyfallidis et al. Recognition of white matter
bundles using local and global streamline-based registration and
clustering, Neuroimage, 2017.
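
        Examples
        --------
        A hedged sketch of typical use, assuming ``bundles``, ``sft``,
        ``mapping``, and ``img`` come from earlier tractography and
        registration steps:

        >>> seg = Segmentation(seg_algo='AFQ')  # doctest: +SKIP
        >>> fiber_groups, meta = seg.segment(
        ...     bundles, sft, mapping, img)  # doctest: +SKIP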
"""
self.img_affine = img.affine
self.img = img
self.logger.info("Preprocessing Streamlines")
tg = self._read_tg(tg)
# These are calculated as-needed
self._fg_array = None
self._crosses = None
        # If resampling, overwrite the sft:
if self.nb_points:
self.tg = StatefulTractogram(
dps.set_number_of_points(self.tg.streamlines, self.nb_points),
self.tg, self.tg.space)
if reg_template is None:
reg_template = afd.read_mni_template()
self.reg_prealign = reg_prealign
self.reg_template = reg_template
self.mapping = mapping
self.bundle_dict = bundle_dict
if not isinstance(self.bundle_dict, BundleDict):
self.bundle_dict = BundleDict(self.bundle_dict)
if self.seg_algo == "afq":
fiber_groups = self.segment_afq()
elif self.seg_algo.startswith("reco"):
fiber_groups = self.segment_reco()
else:
            raise ValueError(f"The seg_algo input is {self.seg_algo}, "
                             "which is not recognized")
if reset_tg_space:
# Return the input to the original space when you are done:
self.tg.to_space(self._tg_orig_space)
return fiber_groups
@property
    def fgarray(self):
"""
Streamlines resampled to 20 points.
"""
if self._fg_array is None:
self.logger.info("Resampling Streamlines...")
start_time = time()
self._fg_array = np.array(_resample_tg(self.tg, 20))
self.logger.info((
"Streamlines Resampled "
f"(time: {time()-start_time}s)"))
return self._fg_array
@property
    def crosses(self):
"""
Classify the streamlines by whether they cross the midline.
Creates a crosses attribute which is an array of booleans. Each boolean
corresponds to a streamline, and is whether or not that streamline
crosses the midline.
"""
if self._crosses is None:
            # What is the voxel coordinate of the world-space
            # origin (0, 0, 0)?
zero_coord = np.dot(np.linalg.inv(self.img_affine),
np.array([0, 0, 0, 1]))
self._crosses = np.logical_and(
np.any(self.fgarray[:, :, 0] > zero_coord[0], axis=1),
np.any(self.fgarray[:, :, 0] < zero_coord[0], axis=1))
return self._crosses
    def _return_empty(self, bundle):
"""
Helper function for segment_afq, to return an empty dict under
some conditions.
"""
if self.return_idx:
self.fiber_groups[bundle] = {}
self.fiber_groups[bundle]['sl'] = StatefulTractogram(
[], self.img, Space.VOX)
self.fiber_groups[bundle]['idx'] = np.array([])
else:
self.fiber_groups[bundle] = StatefulTractogram(
[], self.img, Space.VOX)
    def _add_bundle_to_fiber_group(self, b_name, sl, idx, to_flip):
"""
Helper function for segment_afq, to add a bundle
to a fiber group.
"""
sl = _flip_sls(
sl, to_flip,
in_place=False)
sl = StatefulTractogram(
sl,
self.img,
Space.VOX)
if self.return_idx:
self.fiber_groups[b_name] = {}
self.fiber_groups[b_name]['sl'] = sl
self.fiber_groups[b_name]['idx'] = idx
else:
self.fiber_groups[b_name] = sl
    def segment_afq(self, tg=None):
"""
Assign streamlines to bundles using the waypoint ROI approach
Parameters
----------
tg : StatefulTractogram class instance
"""
tg = self._read_tg(tg=tg)
tg.to_vox()
n_streamlines = len(tg)
bundle_votes = np.full(
(n_streamlines, len(self.bundle_dict)),
-np.inf, dtype=np.float32)
bundle_to_flip = np.zeros(
(n_streamlines, len(self.bundle_dict)),
            dtype=np.bool_)
bundle_roi_dists = -np.ones(
(
n_streamlines,
len(self.bundle_dict),
self.bundle_dict.max_includes),
dtype=np.uint32)
self.fiber_groups = {}
self.meta = {}
# We need to calculate the size of a voxel, so we can transform
# from mm to voxel units:
R = self.img_affine[0:3, 0:3]
vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R))))
        # The tolerance is squared where it is used, because we use the
        # squared Euclidean distance in calls to `cdist` to make those
        # calls faster:
if self.dist_to_waypoint is None:
tol = dts.dist_to_corner(self.img_affine)
else:
tol = self.dist_to_waypoint / vox_dim
dist_to_atlas = int(self.dist_to_atlas / vox_dim)
self.logger.info("Assigning Streamlines to Bundles")
for bundle_idx, bundle_name in enumerate(
self.bundle_dict.bundle_names):
self.logger.info(f"Finding Streamlines for {bundle_name}")
# Warp ROIs
self.logger.info(f"Preparing ROIs for {bundle_name}")
start_time = time()
bundle_def = dict(self.bundle_dict.get_b_info(bundle_name))
bundle_def.update(self.bundle_dict.transform_rois(
bundle_name,
self.mapping,
self.img_affine))
self.logger.info(f"Time to prep ROIs: {time()-start_time}s")
if "curvature" in bundle_def:
self.logger.info(f"Loading curvature...")
start_time = time()
if "sft" in bundle_def["curvature"]:
ref_sl = bundle_def["curvature"]["sft"]
else:
ref_sl = load_tractogram(
bundle_def["curvature"]["path"], "same",
bbox_valid_check=False)
moved_ref_sl = self.move_streamlines(
ref_sl, "subject")
moved_ref_sl.to_vox()
moved_ref_sl = moved_ref_sl.streamlines[0]
moved_ref_curve = sl_curve(
moved_ref_sl,
len(moved_ref_sl))
self.logger.info((
"Time to load curves: "
f"{time()-start_time}s"))
b_sls = _SlsBeingRecognized(
tg.streamlines, self.logger,
self.save_intermediates, bundle_name,
self.img, len(bundle_def.get("include", [])))
# filter by probability map
if "prob_map" in bundle_def:
b_sls.initiate_selection("Prob. Map")
# using entire fgarray here only because it is the first step
fiber_probabilities = dts.values_from_volume(
bundle_def["prob_map"].get_fdata(),
self.fgarray, np.eye(4))
fiber_probabilities = np.mean(fiber_probabilities, -1)
if not self.roi_dist_tie_break:
b_sls.bundle_vote = fiber_probabilities
b_sls.select(
fiber_probabilities > self.prob_threshold,
"Prob. Map")
elif not self.roi_dist_tie_break:
b_sls.bundle_vote = np.ones(len(b_sls))
if b_sls and "cross_midline" in bundle_def:
b_sls.initiate_selection("Cross Mid.")
accepted = self.crosses[b_sls.selected_fiber_idxs]
if not bundle_def["cross_midline"]:
accepted = np.invert(accepted)
b_sls.select(accepted, "Cross Mid.")
if b_sls and "start" in bundle_def:
accept_idx = b_sls.initiate_selection("startpoint")
clean_by_endpoints(
b_sls.get_selected_sls(),
bundle_def["start"],
0,
tol=dist_to_atlas,
flip_sls=b_sls.sls_flipped,
accepted_idxs=accept_idx)
if not b_sls.oriented_yet:
accepted_idx_flipped = clean_by_endpoints(
b_sls.get_selected_sls(),
bundle_def["start"],
-1,
tol=dist_to_atlas)
b_sls.reorient(accepted_idx_flipped)
accept_idx = np.logical_xor(
accepted_idx_flipped, accept_idx)
b_sls.select(accept_idx, "startpoint")
if b_sls and "end" in bundle_def:
accept_idx = b_sls.initiate_selection("endpoint")
cleaned_idx = clean_by_endpoints(
b_sls.get_selected_sls(),
bundle_def["end"],
-1,
tol=dist_to_atlas,
flip_sls=b_sls.sls_flipped,
accepted_idxs=accept_idx)
if not b_sls.oriented_yet:
accepted_idx_flipped = clean_by_endpoints(
b_sls.get_selected_sls(),
bundle_def["end"],
0,
tol=dist_to_atlas)
b_sls.reorient(accepted_idx_flipped)
accept_idx = np.logical_xor(
accepted_idx_flipped, accept_idx)
b_sls.select(accept_idx, "endpoint")
if b_sls and (
("min_len" in bundle_def) or ("max_len" in bundle_def)):
accept_idx = b_sls.initiate_selection("length")
min_len = bundle_def.get("min_len", 0) / vox_dim
max_len = bundle_def.get("max_len", np.inf) / vox_dim
for idx, sl in enumerate(b_sls.get_selected_sls()):
sl_len = np.sum(
np.linalg.norm(np.diff(sl, axis=0), axis=1))
if sl_len >= min_len and sl_len <= max_len:
accept_idx[idx] = 1
b_sls.select(accept_idx, "length")
if b_sls and "primary_axis" in bundle_def:
b_sls.initiate_selection("orientation")
accept_idx = clean_by_orientation(
b_sls.get_selected_sls(),
bundle_def["primary_axis"],
bundle_def.get(
"primary_axis_percentage", None))
b_sls.select(accept_idx, "orientation")
if b_sls and "include" in bundle_def:
accept_idx = b_sls.initiate_selection("include")
flip_using_include = len(bundle_def["include"]) > 1\
and not b_sls.oriented_yet
                if 'inc_addtol' in bundle_def:
include_roi_tols = []
for inc_tol in bundle_def["inc_addtol"]:
include_roi_tols.append((inc_tol / vox_dim + tol)**2)
else:
include_roi_tols = [tol**2] * len(bundle_def["include"])
include_rois = []
for include_roi in bundle_def["include"]:
include_rois.append(np.array(
np.where(include_roi.get_fdata())).T)
# with parallel segmentation, the first for loop will
# only collect streamlines and does not need tqdm
if self.parallel_segmentation["engine"] != "serial":
inc_results = paramap(
_check_sl_with_inclusion, b_sls.get_selected_sls(),
func_args=[
include_rois, include_roi_tols],
**self.parallel_segmentation)
else:
inc_results = _check_sls_with_inclusion(
b_sls.get_selected_sls(),
include_rois,
include_roi_tols)
if self.roi_dist_tie_break:
min_dist_coords = np.ones(len(b_sls))
roi_dists = -np.ones(
(len(b_sls), self.bundle_dict.max_includes),
dtype=np.int32)
if flip_using_include:
                    to_flip = np.ones_like(accept_idx, dtype=np.bool_)
for sl_idx, inc_result in enumerate(inc_results):
sl_accepted, sl_dist = inc_result
if sl_accepted:
if self.roi_dist_tie_break:
min_dist_coords[sl_idx] = np.min(sl_dist)
if len(sl_dist) > 1:
roi_dists[sl_idx, :len(sl_dist)] = [
np.argmin(dist, 0)[0]
for dist in sl_dist]
first_roi_idx = roi_dists[sl_idx, 0]
last_roi_idx = roi_dists[
sl_idx, len(sl_dist) - 1]
# Only accept SLs that, when cut, are meaningful
if (len(sl_dist) < 2) or abs(
first_roi_idx - last_roi_idx) > 1:
                                # Flip the sl if it is close to the
                                # second ROI before it is close to the
                                # first ROI
if flip_using_include:
to_flip[sl_idx] =\
first_roi_idx > last_roi_idx
if to_flip[sl_idx]:
roi_dists[sl_idx, :len(sl_dist)] =\
np.flip(roi_dists[
sl_idx, :len(sl_dist)])
accept_idx[sl_idx] = 1
else:
accept_idx[sl_idx] = 1
# see https://github.com/joblib/joblib/issues/945
if (
(self.parallel_segmentation.get(
"engine", "joblib") != "serial")
and (self.parallel_segmentation.get(
"backend", "loky") == "loky")):
from joblib.externals.loky import get_reusable_executor
get_reusable_executor().shutdown(wait=True)
if self.roi_dist_tie_break:
b_sls.bundle_vote = -min_dist_coords
b_sls.roi_dists = roi_dists
if flip_using_include:
b_sls.reorient(to_flip)
b_sls.select(accept_idx, "include")
# Filters streamlines by how well they match
# a curve in orientation and shape but not scale
if b_sls and "curvature" in bundle_def:
accept_idx = b_sls.initiate_selection("curvature")
ref_curve_threshold = np.radians(bundle_def["curvature"].get(
"thresh", 10))
cut = bundle_def["curvature"].get("cut", True)
for idx, sl in enumerate(b_sls.get_selected_sls(
cut=cut, flip=True)):
if len(sl) > 1:
this_sl_curve = sl_curve(sl, len(moved_ref_sl))
dist = sl_curve_dist(this_sl_curve, moved_ref_curve)
if dist <= ref_curve_threshold:
accept_idx[idx] = 1
b_sls.select(accept_idx, "curvature", cut=cut)
if b_sls and "exclude" in bundle_def:
accept_idx = b_sls.initiate_selection("exclude")
                if 'exc_addtol' in bundle_def:
exclude_roi_tols = []
for exc_tol in bundle_def["exc_addtol"]:
exclude_roi_tols.append((exc_tol / vox_dim + tol)**2)
else:
exclude_roi_tols = [tol**2] * len(bundle_def["exclude"])
exclude_rois = []
for exclude_roi in bundle_def["exclude"]:
exclude_rois.append(np.array(
np.where(exclude_roi.get_fdata())).T)
for sl_idx, sl in enumerate(b_sls.get_selected_sls()):
if _check_sl_with_exclusion(
sl, exclude_rois, exclude_roi_tols):
accept_idx[sl_idx] = 1
b_sls.select(accept_idx, "exclude")
if b_sls and "qb_thresh" in bundle_def:
b_sls.initiate_selection("qb_thresh")
cut = self.clip_edges or ("bundlesection" in bundle_def)
qbx = QuickBundles(
bundle_def["qb_thresh"] / vox_dim,
AveragePointwiseEuclideanMetric(
ResampleFeature(nb_points=12)))
clusters = qbx.cluster(b_sls.get_selected_sls(
cut=cut, flip=True))
cleaned_idx = clusters[np.argmax(
clusters.clusters_sizes())].indices
b_sls.select(cleaned_idx, "qb_thresh", cut=cut)
if b_sls:
accept_idx = b_sls.initiate_selection("Mahalanobis")
clean_params = bundle_def.get("mahal", {})
clean_params = {
**self.cleaning_params,
**clean_params}
clean_params["return_idx"] = True
cut = self.clip_edges or ("bundlesection" in bundle_def)
_, cleaned_idx = clean_bundle(
b_sls.get_selected_sls(cut=cut, flip=True),
**clean_params)
b_sls.select(cleaned_idx, "Mahalanobis", cut=cut)
if b_sls and not b_sls.oriented_yet:
raise ValueError(
"pyAFQ was unable to consistently orient streamlines "
f"in bundle {bundle_name} using the provided ROIs. "
"This can be fixed by including at least 2 "
"waypoint ROIs, or by using "
"endpoint ROIs.")
if b_sls:
bundle_votes[
b_sls.selected_fiber_idxs,
bundle_idx] = b_sls.bundle_vote.copy()
bundle_to_flip[
b_sls.selected_fiber_idxs,
bundle_idx] = b_sls.sls_flipped.copy()
if hasattr(b_sls, "roi_dists"):
bundle_roi_dists[
b_sls.selected_fiber_idxs,
bundle_idx
] = b_sls.roi_dists.copy()
if self.save_intermediates is not None:
os.makedirs(self.save_intermediates, exist_ok=True)
bc_path = op.join(self.save_intermediates,
"sls_bundle_votes.npy")
np.save(bc_path, bundle_votes)
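        # Assign each streamline to the bundle with the highest vote;
        # streamlines with no finite vote remain unassigned (-1):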
bundle_choice = np.argmax(bundle_votes, -1)
bundle_choice[bundle_votes.max(-1) == -np.inf] = -1
# We do another round through, so that we can orient all the
# streamlines within a bundle in the same orientation with respect to
# the ROIs. This order is ARBITRARY but CONSISTENT (going from ROI0
# to ROI1).
self.logger.info("Re-orienting streamlines to consistent directions")
for bundle_idx, bundle in enumerate(self.bundle_dict.bundle_names):
self.logger.info(f"Processing {bundle}")
select_idx = np.where(bundle_choice == bundle_idx)[0]
if len(select_idx) == 0:
# There's nothing here, set and move to the next bundle:
if "bundlesection" in self.bundle_dict.get_b_info(bundle):
for sb_name in self.bundle_dict.get_b_info(bundle)[
"bundlesection"]:
self._return_empty(sb_name)
else:
self._return_empty(bundle)
continue
# Use a list here, because ArraySequence doesn't support item
# assignment:
select_sl = list(tg.streamlines[select_idx])
roi_dists = bundle_roi_dists[select_idx, bundle_idx, :]
n_includes = len(self.bundle_dict.get_b_info(
bundle).get("include", []))
if self.clip_edges and n_includes > 1:
self.logger.info("Clipping Streamlines by ROI")
_cut_sls_by_dist(
select_sl, roi_dists,
(0, n_includes - 1), in_place=True)
to_flip = bundle_to_flip[select_idx, bundle_idx]
if "bundlesection" in self.bundle_dict[bundle]:
for sb_name, sb_include_cuts in self.bundle_dict.get_b_info(
bundle)["bundlesection"].items():
bundlesection_select_sl = _cut_sls_by_dist(
select_sl, roi_dists,
sb_include_cuts, in_place=False)
self._add_bundle_to_fiber_group(
sb_name, bundlesection_select_sl, select_idx, to_flip)
self._add_bundle_to_meta(sb_name, bundle_def)
else:
self._add_bundle_to_fiber_group(
bundle, select_sl, select_idx, to_flip)
self._add_bundle_to_meta(bundle, bundle_def)
return self.fiber_groups, self.meta
    def move_streamlines(self, tg, to="template"):
        """Move a whole-brain tractogram between subject space and
        template space, using the mapping estimated during
        registration.

        Parameters
        ----------
        tg : StatefulTractogram
            The tractogram to move.
        to : str
            Either "template" or "subject". Default: "template".
        """
tg_og_space = tg.space
if isinstance(self.mapping, ConformedFnirtMapping):
if to != "subject":
raise ValueError(
"Attempted to transform streamlines to template using "
"unsupported mapping. "
"Use something other than Fnirt.")
tg.to_vox()
moved_sl = []
for sl in tg.streamlines:
moved_sl.append(self.mapping.transform_inverse_pts(sl))
else:
tg.to_rasmm()
if to == "template":
volume = self.mapping.forward
else:
volume = self.mapping.backward
delta = dts.values_from_volume(
volume,
tg.streamlines, np.eye(4))
moved_sl = dts.Streamlines(
[d + s for d, s in zip(delta, tg.streamlines)])
if to == "template":
ref = self.reg_template
else:
ref = self.img
moved_sft = StatefulTractogram(
moved_sl,
ref,
Space.RASMM)
if self.save_intermediates is not None:
save_tractogram(
moved_sft,
op.join(self.save_intermediates,
f'sls_in_{to}.trk'),
bbox_valid_check=False)
tg.to_space(tg_og_space)
return moved_sft
    def segment_reco(self, tg=None):
"""
        Segment streamlines using the RecoBundles algorithm
        [Garyfallidis2017]_.
Parameters
----------
tg : StatefulTractogram class instance
A whole-brain tractogram to be segmented.
        Returns
        -------
        fiber_groups : dict
            Keys are names of the bundles, values are Streamline
            objects. The streamlines in each object have all been
            oriented to have the same orientation (using
            `dts.orient_by_streamline`).
        meta : dict
            Currently always empty, for symmetry with `segment_afq`.
"""
tg = self._read_tg(tg=tg)
fiber_groups = {}
# We generate our instance of RB with the moved streamlines:
self.logger.info("Extracting Bundles")
# If doing a presegmentation based on ROIs then initialize
# that segmentation and segment using ROIs, else
# RecoBundles based on the whole brain tractogram
if self.presegment_bundle_dict is not None:
roiseg = Segmentation(**self.presegment_kwargs)
roiseg.segment(
self.presegment_bundle_dict,
self.tg,
self.mapping,
self.img,
reg_template=self.reg_template,
reg_prealign=self.reg_prealign)
roiseg_fg = roiseg.fiber_groups
else:
moved_sl = self.move_streamlines(tg).streamlines
rb = RecoBundles(moved_sl, verbose=False, rng=self.rng)
# Next we'll iterate over bundles, registering each one:
bundle_list = list(self.bundle_dict.keys())
if 'whole_brain' in bundle_list:
bundle_list.remove('whole_brain')
self.logger.info("Assigning Streamlines to Bundles")
for bundle in bundle_list:
self.logger.info(f"Finding streamlines for {bundle}")
b_info = self.bundle_dict[bundle]
model_sl = b_info['sl']
# If doing a presegmentation based on ROIs then initialize rb after
# Filtering the whole brain tractogram to pass through ROIs
if self.presegment_bundle_dict is not None:
afq_bundle_name = BUNDLE_RECO_2_AFQ.get(bundle, bundle)
if "return_idx" in self.presegment_kwargs\
and self.presegment_kwargs["return_idx"]:
indiv_tg = roiseg_fg[afq_bundle_name]['sl']
else:
indiv_tg = roiseg_fg[afq_bundle_name]
if len(indiv_tg.streamlines) < 1:
self.logger.warning((
f"No streamlines found by waypoint ROI "
f"pre-segmentation for {bundle}. Using entire"
f" tractography instead."))
indiv_tg = tg
# Now rb should be initialized based on the fiber group coming
# out of the roi segmentation
indiv_tg = StatefulTractogram(
indiv_tg.streamlines,
self.img,
Space.VOX)
indiv_tg.to_rasmm()
moved_sl = self.move_streamlines(indiv_tg).streamlines
rb = RecoBundles(
moved_sl,
verbose=False,
rng=self.rng)
if self.save_intermediates is not None:
if self.presegment_bundle_dict is not None:
moved_fname = f"{bundle}_presegmentation.trk"
else:
moved_fname = "whole_brain.trk"
moved_sft = StatefulTractogram(
moved_sl,
self.reg_template,
Space.RASMM)
save_tractogram(
moved_sft,
op.join(self.save_intermediates,
moved_fname),
bbox_valid_check=False)
model_sft = StatefulTractogram(
model_sl,
self.reg_template,
Space.RASMM)
save_tractogram(
model_sft,
op.join(self.save_intermediates,
f"{bundle}_model.trk"),
bbox_valid_check=False)
            # Either the whole-brain tractogram or the ROI-presegmented
            # fiber group goes to rb.recognize
_, rec_labels = rb.recognize(model_bundle=model_sl,
model_clust_thr=self.model_clust_thr,
reduction_thr=self.reduction_thr,
reduction_distance='mdf',
slr=True,
slr_metric='asymmetric',
pruning_distance='mdf')
# Use the streamlines in the original space:
if self.presegment_bundle_dict is None:
recognized_sl = tg.streamlines[rec_labels]
else:
recognized_sl = indiv_tg.streamlines[rec_labels]
if self.refine and len(recognized_sl) > 0:
_, rec_labels = rb.refine(model_sl, recognized_sl,
self.model_clust_thr,
reduction_thr=self.reduction_thr,
pruning_thr=self.pruning_thr)
if self.presegment_bundle_dict is None:
recognized_sl = tg.streamlines[rec_labels]
else:
recognized_sl = indiv_tg.streamlines[rec_labels]
standard_sl = next(iter(b_info['centroid']))
oriented_sl = dts.orient_by_streamline(recognized_sl, standard_sl)
self.logger.info(
f"{len(oriented_sl)} streamlines selected with Recobundles")
if self.return_idx:
fiber_groups[bundle] = {}
fiber_groups[bundle]['idx'] = rec_labels
fiber_groups[bundle]['sl'] = StatefulTractogram(oriented_sl,
self.img,
Space.RASMM)
else:
fiber_groups[bundle] = StatefulTractogram(oriented_sl,
self.img,
Space.RASMM)
self.fiber_groups = fiber_groups
return fiber_groups, {}
def sl_curve(sl, n_points):
"""
Calculate the direction of the displacement between
each point along a streamline
Parameters
----------
sl : 2d array-like
        Streamline to calculate displacements for.
n_points : int
Number of points to resample the streamline to
    Returns
    -------
    2d array of shape (n_points - 1, 3) with the displacement between
    each pair of adjacent resampled points, normalized to unit length.
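
    Examples
    --------
    On a straight line, every normalized displacement is the same unit
    vector:

    >>> sl = np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]])
    >>> sl_curve(sl, 3)  # doctest: +SKIP
    array([[1., 0., 0.],
           [1., 0., 0.]])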
"""
# Resample to a standardized number of points
resampled_sl = dps.set_number_of_points(
sl,
n_points)
# displacement at each point
resampled_sl_diff = np.diff(resampled_sl, axis=0)
# normalize this displacement
resampled_sl_diff = resampled_sl_diff / np.linalg.norm(
resampled_sl_diff, axis=1)[:, None]
return resampled_sl_diff
def sl_curve_dist(curve1, curve2):
"""
Calculate the mean angle using the directions of displacement
between two streamlines
Parameters
----------
curve1, curve2 : 2d array-like
Two curves calculated from sl_curve.
Returns
-------
The mean angle between each curve across all steps, in radians
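
    Examples
    --------
    Two curves at right angles at every step have a mean angle of
    pi / 2:

    >>> c1 = np.array([[1., 0., 0.], [1., 0., 0.]])
    >>> c2 = np.array([[0., 1., 0.], [0., 1., 0.]])
    >>> sl_curve_dist(c1, c2)  # doctest: +SKIP
    1.5707963267948966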
"""
return np.mean(np.arccos(np.sum(curve1 * curve2, axis=1)))
def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3,
length_threshold=4, min_sl=20, stat='mean',
return_idx=False):
"""
    Clean a segmented fiber group based on the Mahalanobis distance of
each streamline
Parameters
----------
tg : StatefulTractogram class instance or ArraySequence
A whole-brain tractogram to be segmented.
n_points : int, optional
Number of points to resample streamlines to.
Default: 100
clean_rounds : int, optional.
Number of rounds of cleaning based on the Mahalanobis distance from
the mean of extracted bundles. Default: 5
distance_threshold : float, optional.
Threshold of cleaning based on the Mahalanobis distance (the units are
standard deviations). Default: 3.
length_threshold: float, optional
Threshold for cleaning based on length (in standard deviations). Length
of any streamline should not be *more* than this number of stdevs from
the mean length.
min_sl : int, optional.
Number of streamlines in a bundle under which we will
not bother with cleaning outliers. Default: 20.
stat : callable or str, optional.
The statistic of each node relative to which the Mahalanobis is
        calculated. Default: 'mean' (but can also use 'median', etc.)
return_idx : bool
Whether to return indices in the original streamlines.
Default: False.
Returns
-------
    A StatefulTractogram class instance containing only the streamlines
    that have a Mahalanobis distance smaller than `distance_threshold`
    from the mean of each one of the nodes.
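
    Examples
    --------
    A hedged sketch, assuming ``bundle_sft`` is a StatefulTractogram
    holding a single, already-segmented bundle:

    >>> cleaned = clean_bundle(bundle_sft)  # doctest: +SKIP
    >>> cleaned, kept_idx = clean_bundle(
    ...     bundle_sft, return_idx=True)  # doctest: +SKIP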
"""
# Convert string to callable, if that's what you got.
if isinstance(stat, str):
stat = getattr(np, stat)
if hasattr(tg, "streamlines"):
streamlines = tg.streamlines
else:
streamlines = dts.Streamlines(tg)
# We don't even bother if there aren't enough streamlines:
if len(streamlines) < min_sl:
if return_idx:
return tg, np.arange(len(streamlines))
else:
return tg
# Resample once up-front:
fgarray = np.asarray(_resample_tg(streamlines, n_points))
# Keep this around, so you can use it for indexing at the very end:
idx = np.arange(len(fgarray))
# get lengths of each streamline
lengths = np.array([sl.shape[0] for sl in streamlines])
# We'll only do this for clean_rounds
rounds_elapsed = 0
idx_belong = idx
while (rounds_elapsed < clean_rounds) and (np.sum(idx_belong) > min_sl):
# Update by selection:
idx = idx[idx_belong]
fgarray = fgarray[idx_belong]
lengths = lengths[idx_belong]
rounds_elapsed += 1
# This calculates the Mahalanobis for each streamline/node:
m_dist = gaussian_weights(
fgarray, return_mahalnobis=True,
n_points=n_points, stat=stat)
logger.debug(f"Shape of fgarray: {np.asarray(fgarray).shape}")
logger.debug(f"Shape of m_dist: {m_dist.shape}")
logger.debug(f"Maximum m_dist: {np.max(m_dist)}")
logger.debug((
f"Maximum m_dist for each fiber: "
f"{np.max(m_dist, axis=1)}"))
length_z = zscore(lengths)
logger.debug(f"Shape of length_z: {length_z.shape}")
logger.debug(f"Maximum length_z: {np.max(length_z)}")
logger.debug((
"length_z for each fiber: "
f"{length_z}"))
if not (
np.any(m_dist > distance_threshold)
or np.any(length_z > length_threshold)):
break
# Select the fibers that have Mahalanobis smaller than the
# threshold for all their nodes:
idx_dist = np.all(m_dist < distance_threshold, axis=-1)
idx_len = length_z < length_threshold
idx_belong = np.logical_and(idx_dist, idx_len)
if np.sum(idx_belong) < min_sl:
# need to sort and return exactly min_sl:
idx_belong = np.argsort(np.sum(
m_dist, axis=-1))[:min_sl].astype(int)
logger.debug((
f"At rounds elapsed {rounds_elapsed}, "
"minimum streamlines reached"))
else:
idx_removed = idx_belong == 0
logger.debug((
f"Rounds elapsed: {rounds_elapsed}, "
f"num removed: {np.sum(idx_removed)}"))
logger.debug(f"Removed indicies: {np.where(idx_removed)[0]}")
# Select based on the variable that was keeping track of things for us:
if hasattr(tg, "streamlines"):
out = StatefulTractogram(tg.streamlines[idx], tg, Space.VOX)
else:
out = streamlines[idx]
if return_idx:
return out, idx
else:
return out
# Helper functions for segmenting using waypoint ROIs
# they are not a part of the class because we do not want
# copies of the class to be parallelized
def _check_sls_with_inclusion(sls, include_rois, include_roi_tols):
for sl in sls:
yield _check_sl_with_inclusion(
sl,
include_rois,
include_roi_tols)
def _check_sl_with_inclusion(sl, include_rois,
include_roi_tols):
"""
    Helper function to check that a streamline is close to a list of
    inclusion ROIs.
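
    Returns
    -------
    tuple
        ``(accepted, dists)``: whether the streamline is close enough
        to every inclusion ROI, and a list of per-ROI squared-distance
        matrices (empty if the check failed).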
"""
dist = []
for ii, roi in enumerate(include_rois):
# Use squared Euclidean distance, because it's faster:
dist.append(cdist(sl, roi, 'sqeuclidean'))
if np.min(dist[-1]) > include_roi_tols[ii]:
# Too far from one of them:
return False, []
# Apparently you checked all the ROIs and it was close to all of them
return True, dist
def _check_sl_with_exclusion(sl, exclude_rois,
exclude_roi_tols):
""" Helper function to check that a streamline is not too close to a
list of exclusion ROIs.
"""
for ii, roi in enumerate(exclude_rois):
# Use squared Euclidean distance, because it's faster:
if np.min(cdist(sl, roi, 'sqeuclidean')) < exclude_roi_tols[ii]:
return False
# Either there are no exclusion ROIs, or you are not close to any:
return True
def _flip_sls(select_sl, idx_to_flip, in_place=False):
"""
Helper function to flip streamlines
"""
if in_place:
flipped_sl = select_sl
else:
flipped_sl = [None] * len(select_sl)
for ii, sl in enumerate(select_sl):
if idx_to_flip[ii]:
flipped_sl[ii] = sl[::-1]
else:
flipped_sl[ii] = sl
return flipped_sl
def _cut_sls_by_dist(select_sl, roi_dists, roi_idxs,
in_place=False):
"""
Helper function to cut streamlines according to which points
are closest to certain rois.
Parameters
----------
    select_sl : list of arrays
        The streamlines to cut.
    roi_dists : 2d array
        For each streamline, the index of the node closest to each
        inclusion ROI.
    roi_idxs : tuple of int
        Two indices into the list of inclusion ROIs, marking where to
        cut.
    in_place : bool
        Whether to modify select_sl in place.
"""
if in_place:
cut_sl = select_sl
else:
cut_sl = [None] * len(select_sl)
for idx, this_sl in enumerate(select_sl):
if roi_idxs[0] == -1:
min0 = 0
else:
min0 = int(roi_dists[idx, roi_idxs[0]])
if roi_idxs[1] == -1:
min1 = len(this_sl)
else:
min1 = int(roi_dists[idx, roi_idxs[1]])
# handle if sls not flipped
if min0 > min1:
min0, min1 = min1, min0
# If the point that is closest to the first ROI
# is the same as the point closest to the second ROI,
# include the surrounding points to make a streamline.
if min0 == min1:
min1 = min1 + 1
min0 = min0 - 1
cut_sl[idx] = this_sl[min0:min1]
return cut_sl
def clean_by_orientation(streamlines, primary_axis, tol=None):
"""
    Filter streamlines by whether their primary orientation is along a
    given cardinal axis.

    Parameters
    ----------
    streamlines : sequence of N by 3 arrays
        Where N is number of nodes in the array, the collection of
        streamlines to filter down to.
    primary_axis : int
        The cardinal axis (0 for x, 1 for y, 2 for z) along which
        accepted streamlines should primarily run.
    tol : float, optional
        If provided, the minimum percentage of a streamline's total
        displacement that must lie along the primary axis.
        Default: None.

    Returns
    -------
    cleaned_idx : 1d boolean array
        Streamlines that passed cleaning: the logical AND of passing
        the along-bundle criterion and the endpoint-difference
        criterion.
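
    Examples
    --------
    A streamline running mostly along x (axis 0) passes, while one
    running mostly along y does not:

    >>> sls = [np.array([[0., 0., 0.], [5., 1., 0.]]),
    ...        np.array([[0., 0., 0.], [1., 5., 0.]])]
    >>> clean_by_orientation(sls, primary_axis=0)  # doctest: +SKIP
    array([ True, False])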
"""
axis_diff = np.zeros((len(streamlines), 3))
endpoint_diff = np.zeros((len(streamlines), 3))
for ii, sl in enumerate(streamlines):
# endpoint diff is between first and last
endpoint_diff[ii, :] = np.abs(sl[0, :] - sl[-1, :])
# axis diff is difference between the nodes, along
axis_diff[ii, :] = np.sum(np.abs(np.diff(sl, axis=0)), axis=0)
orientation_along = np.argmax(axis_diff, axis=1)
along_accepted_idx = orientation_along == primary_axis
if tol is not None:
percentage_primary = 100 * axis_diff[:, primary_axis] / np.sum(
axis_diff, axis=1)
logger.debug((
"Maximum primary percentage found: "
f"{np.max(percentage_primary)}"))
along_accepted_idx = np.logical_and(
along_accepted_idx, percentage_primary > tol)
orientation_end = np.argmax(endpoint_diff, axis=1)
end_accepted_idx = orientation_end == primary_axis
cleaned_idx = np.logical_and(
along_accepted_idx,
end_accepted_idx)
return cleaned_idx
def clean_by_endpoints(streamlines, target, target_idx, tol=0,
flip_sls=None, accepted_idxs=None):
"""
Clean a collection of streamlines based on an endpoint ROI.
Filters down to only include items that have their start or end points
close to the targets.
Parameters
----------
streamlines : sequence of N by 3 arrays
Where N is number of nodes in the array, the collection of
streamlines to filter down to.
target: Nifti1Image
Nifti1Image containing a boolean representation of the ROI.
target_idx: int.
Index within each streamline to check if within the target region.
Typically 0 for startpoint ROIs or -1 for endpoint ROIs.
If using flip_sls, this becomes (len(sl) - this_idx - 1) % len(sl)
tol : int, optional
A distance tolerance (in units that the coordinates
of the streamlines are represented in). Default: 0, which means that
the endpoint is exactly in the coordinate of the target ROI.
    flip_sls : 1d array, optional
        Array of length len(streamlines); nonzero entries indicate that
        the corresponding streamline should be treated as flipped.
    accepted_idxs : 1d array, optional
        Boolean array, where entries correspond to each streamline,
        and streamlines that pass cleaning will be set to 1.
    Returns
    -------
    1d boolean array of the streamlines that survive cleaning.
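
    Examples
    --------
    A hedged sketch, assuming ``sls`` are streamlines in the voxel
    coordinates of ``startpoint_roi``, a boolean Nifti1Image:

    >>> keep = clean_by_endpoints(
    ...     sls, startpoint_roi, 0, tol=2)  # doctest: +SKIP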
"""
if accepted_idxs is None:
        accepted_idxs = np.zeros(len(streamlines), dtype=np.bool_)
if flip_sls is None:
flip_sls = np.zeros(len(streamlines))
flip_sls = flip_sls.astype(int)
roi = target.get_fdata()
if tol > 0:
roi = binary_dilation(
roi,
iterations=tol)
for ii, sl in enumerate(streamlines):
this_idx = target_idx
if flip_sls[ii]:
this_idx = (len(sl) - this_idx - 1) % len(sl)
xx, yy, zz = sl[this_idx].astype(int)
accepted_idxs[ii] = roi[xx, yy, zz]
return accepted_idxs