Source code for AFQ.api.participant

import nibabel as nib
import os.path as op
from time import time
import logging
from tqdm import tqdm
import numpy as np
import tempfile
import math

from PIL import Image, ImageDraw, ImageFont

from AFQ.definitions.mapping import SlrMap
from AFQ.api.utils import (
    check_attribute, AFQclass_doc,
    export_all_helper, valid_exports_string)

from AFQ.tasks.data import get_data_plan
from AFQ.tasks.mapping import get_mapping_plan
from AFQ.tasks.tractography import get_tractography_plan
from AFQ.tasks.segmentation import get_segmentation_plan
from AFQ.tasks.viz import get_viz_plan
from AFQ.tasks.utils import get_base_fname
from AFQ.utils.path import apply_cmd_to_afq_derivs
from AFQ.viz.utils import BEST_BUNDLE_ORIENTATIONS, trim, get_eye


__all__ = ["ParticipantAFQ"]


class ParticipantAFQ(object):
    # An f-string is never stored as a docstring, so the shared class
    # documentation is attached explicitly instead of being written as
    # f"""{AFQclass_doc}""".
    __doc__ = AFQclass_doc

    def __init__(self,
                 dwi_data_file, bval_file, bvec_file,
                 output_dir,
                 **kwargs):
        """
        Initialize a ParticipantAFQ object.

        Parameters
        ----------
        dwi_data_file : str
            Path to DWI data file.
        bval_file : str
            Path to bval file.
        bvec_file : str
            Path to bvec file.
        output_dir : str
            Path to output directory. Must already exist.
        kwargs : additional optional parameters
            You can set additional parameters for any step
            of the process. See :ref:`usage/kwargs` for more details.

        Raises
        ------
        TypeError
            If any of the four path arguments is not a str.
        ValueError
            If output_dir does not exist, or if the misspelled keyword
            ``tractography_params`` is passed.

        Examples
        --------
        api.ParticipantAFQ(
            dwi_data_file, bval_file, bvec_file, output_dir,
            csd_sh_order=4)
        api.ParticipantAFQ(
            dwi_data_file, bval_file, bvec_file, output_dir,
            reg_template_spec="mni_t2", reg_subject_spec="b0")

        Notes
        -----
        In tracking_params, parameters with the suffix mask which are also
        an image from AFQ.definitions.image will be handled automatically by
        the api.
        """
        # Checked in the same order, and with the same messages, as before.
        for arg_name, arg_value in (
                ("output_dir", output_dir),
                ("dwi_data_file", dwi_data_file),
                ("bval_file", bval_file),
                ("bvec_file", bvec_file)):
            if not isinstance(arg_value, str):
                raise TypeError(f"{arg_name} must be a str")

        if not op.exists(output_dir):
            raise ValueError(
                f"output_dir does not exist: {output_dir}")

        # Catch a common misspelling early instead of silently ignoring it.
        if "tractography_params" in kwargs:
            raise ValueError((
                "unrecognized parameter tractography_params, "
                "did you mean tracking_params ?"))

        self.logger = logging.getLogger('AFQ')

        self.kwargs = dict(
            dwi_data_file=dwi_data_file,
            bval_file=bval_file,
            bvec_file=bvec_file,
            output_dir=output_dir,
            base_fname=get_base_fname(output_dir, dwi_data_file),
            **kwargs)
        self.make_workflow()

    def make_workflow(self):
        """
        (Re)build the chained plans and store the result in ``self.wf_dict``.

        The individual plans are created from ``self.kwargs`` and then
        called one after another; each call also receives the results of
        all previous calls under the key ``"<name>_imap"``. ``self.wf_dict``
        is set to the result of the last plan in the chain.
        """
        # construct pimms plans
        mapping_definition = self.kwargs.get("mapping_definition")
        if isinstance(mapping_definition, SlrMap):
            plans = {  # if using SLR map, do tractography first
                "data": get_data_plan(self.kwargs),
                "tractography": get_tractography_plan(
                    self.kwargs
                ),
                "mapping": get_mapping_plan(
                    self.kwargs,
                    use_sls=True
                ),
                "segmentation": get_segmentation_plan(self.kwargs),
                "viz": get_viz_plan(self.kwargs)}
        else:
            plans = {  # Otherwise, do mapping first
                "data": get_data_plan(self.kwargs),
                "mapping": get_mapping_plan(self.kwargs),
                "tractography": get_tractography_plan(
                    self.kwargs
                ),
                "segmentation": get_segmentation_plan(self.kwargs),
                "viz": get_viz_plan(self.kwargs)}

        # chain together a complete plan from individual plans
        previous_data = {}
        last_imap = None
        for name, plan in plans.items():
            last_imap = plan(
                **self.kwargs,
                **previous_data)
            previous_data[f"{name}_imap"] = last_imap

        # The last plan in the chain is the entry point used by export().
        self.wf_dict = last_imap

    def export(self, attr_name="help"):
        """
        Export a specific output. To print a list of available outputs,
        call export without arguments.

        Parameters
        ----------
        attr_name : str
            Name of the output to export.
            Default: "help"

        Returns
        -------
        output : any
            The specific output, or None if called without arguments.
        """
        # check_attribute reports which section of the workflow holds the
        # output; None means it is looked up on the top-level dict.
        section = check_attribute(attr_name)

        if section is None:
            return self.wf_dict[attr_name]
        return self.wf_dict[section][attr_name]

    def export_up_to(self, attr_name="help"):
        # Documented through the ``__doc__`` assignment that follows this
        # method: the text embeds ``valid_exports_string``, and an f-string
        # placed here would not be picked up as a docstring.
        section = check_attribute(attr_name)
        wf_dict = self.wf_dict
        if section is not None:
            wf_dict = wf_dict[section]
        for dependent in wf_dict.plan.dependencies[attr_name]:
            self.export(dependent)

    export_up_to.__doc__ = f"""
        Export all derivatives necessary for a specific output.
        To print a list of available outputs,
        call export_up_to without arguments.
        {valid_exports_string}

        Parameters
        ----------
        attr_name : str
            Name of the output to export up to.
            Default: "help"
        """

    def export_all(self, viz=True, xforms=True, indiv=True):
        # Documented through the ``__doc__`` assignment that follows this
        # method (see the note in export_up_to for the reason).
        start_time = time()
        seg_algo = self.export("segmentation_params").get("seg_algo", "AFQ")
        export_all_helper(self, seg_algo, xforms, indiv, viz)
        self.logger.info(
            f"Time taken for export all: {time() - start_time}")

    export_all.__doc__ = f"""
        Exports all the possible outputs
        {valid_exports_string}

        Parameters
        ----------
        viz : bool
            Whether to output visualizations. This includes tract profile
            plots, a figure containing all bundles, and, if using the AFQ
            segmentation algorithm, individual bundle figures.
            Default: True
        xforms : bool
            Whether to output the reg_template image in subject space and,
            depending on if it is possible based on the mapping used, to
            output the b0 in template space.
            Default: True
        indiv : bool
            Whether to output individual bundles in their own files, in
            addition to the one file containing all bundles. If using
            the AFQ segmentation algorithm, individual ROIs are also
            output.
            Default: True
        """

    def participant_montage(self, images_per_row=2):
        """
        Generate montage of all bundles for a given subject.

        Each bundle is rendered to its own snapshot, trimmed, given a title
        strip with the bundle name, and pasted onto a grid. The result is
        saved as ``bundle_montage.png`` in the output directory.

        Parameters
        ----------
        images_per_row : int
            Number of bundle images per row in output file.
            Default: 2

        Returns
        -------
        list of str
            Absolute filename(s) of the montage image(s) written.
        """
        text_sz = 70  # height of the title strip and font size

        bundle_dict = self.export("bundle_dict")
        self.logger.info("Generating Montage...")
        viz_backend = self.export("viz_backend")
        best_scalar = self.export(self.export("best_scalar"))

        # Axes with a negative diagonal entry in the affine are flipped for
        # display. This does not depend on the bundle, so compute it once.
        dwi_affine = self.export("dwi_affine")
        flip_axes = [dwi_affine[i, i] < 0 for i in range(3)]

        titled_imgs = []
        max_width = 0
        max_height = 0

        # A private temporary directory keeps concurrent runs from
        # overwriting each other's snapshots and removes them afterwards.
        with tempfile.TemporaryDirectory() as tdir:
            snapshot_fnames = []
            for ii, bundle_name in enumerate(tqdm(bundle_dict)):
                figure = viz_backend.visualize_volume(
                    best_scalar,
                    flip_axes=flip_axes,
                    interact=False,
                    inline=False)
                figure = viz_backend.visualize_bundles(
                    self.export("bundles"),
                    shade_by_volume=best_scalar,
                    color_by_direction=True,
                    flip_axes=flip_axes,
                    bundle=bundle_name,
                    figure=figure,
                    interact=False,
                    inline=False)

                view, direc = BEST_BUNDLE_ORIENTATIONS.get(
                    bundle_name, ("Axial", "Top"))
                eye = get_eye(view, direc)

                this_fname = op.join(tdir, f"t{ii}.png")
                if "plotly" in viz_backend.backend:
                    figure.update_layout(
                        scene_camera=dict(
                            projection=dict(type="orthographic"),
                            up={"x": 0, "y": 0, "z": 1},
                            eye=eye,
                            center=dict(x=0, y=0, z=0)),
                        showlegend=False)
                    figure.write_image(this_fname, scale=4)

                    # temporary fix for memory leak
                    import plotly.io as pio
                    pio.kaleido.scope._shutdown_kaleido()
                else:
                    from dipy.viz import window
                    direc = np.fromiter(eye.values(), dtype=int)
                    # Only the shape is needed, so read it from the image
                    # header rather than loading the whole volume.
                    data_shape = np.asarray(
                        nib.load(self.export("b0")).shape)
                    figure.set_camera(
                        position=direc * data_shape,
                        focal_point=data_shape // 2,
                        view_up=(0, 0, 1))
                    figure.zoom(0.5)
                    window.snapshot(
                        figure, fname=this_fname, size=(600, 600))
                snapshot_fnames.append(this_fname)

            font = ImageFont.truetype("Arial", text_sz)
            for bundle_name, this_fname in zip(bundle_dict, snapshot_fnames):
                # The pixels are copied into `titled` before the file is
                # closed, so nothing refers to the temporary file afterwards.
                with Image.open(this_fname) as this_img:
                    try:
                        trimmed = trim(this_img)
                    except IndexError:  # this_img is a picture of nothing
                        trimmed = this_img

                    width, height = trimmed.size
                    titled = Image.new(
                        trimmed.mode,
                        (width, height + text_sz),
                        color=(255, 255, 255))
                    titled.paste(trimmed, (0, text_sz))

                ImageDraw.Draw(titled).text(
                    (0, 0), bundle_name, (0, 0, 0), font=font)

                max_width = max(max_width, titled.size[0])
                max_height = max(max_height, titled.size[1])
                titled_imgs.append(titled)

        n_cols = images_per_row
        n_rows = math.ceil(len(bundle_dict) / images_per_row)
        montage = Image.new(
            'RGB',
            (max_width * n_cols, max_height * n_rows),
            color="white")

        for ii, titled in enumerate(titled_imgs):
            # Fill the grid row by row.
            y_pos, x_pos = divmod(ii, n_cols)
            montage.paste(
                titled.resize((max_width, max_height)),
                (x_pos * max_width, y_pos * max_height))

        save_path = op.abspath(op.join(
            self.kwargs["output_dir"],
            "bundle_montage.png"))
        montage.save(save_path)
        return [save_path]

    def cmd_outputs(self, cmd="rm", dependent_on=None, exceptions=None,
                    suffix=""):
        """
        Perform some command on some or all outputs of pyafq.
        This is useful if you change a parameter and need
        to recalculate derivatives that depend on it.
        Some examples: cp, mv, rm .
        -r will be automatically added when necessary.

        Parameters
        ----------
        cmd : str
            Command to run on outputs. Default: 'rm'
        dependent_on : str or None
            Which derivatives to perform command on .
            If None, perform on all.
            If "track", perform on all derivatives that depend on the
            tractography.
            If "recog", perform on all derivatives that depend on the
            bundle recognition.
            If "prof", perform on all derivatives that depend on the
            bundle profiling.
            Default: None
        exceptions : list of str or None
            Name outputs that the command should not be applied to.
            None is treated as an empty list.
            Default: None
        suffix : str
            Parts of command that are used after the filename.
            Default: ""
        """
        if exceptions is None:
            exceptions = []

        exception_file_names = []
        for exception in exceptions:
            file_name = self.export(exception)
            if isinstance(file_name, str):
                exception_file_names.append(file_name)
            else:
                self.logger.warning((
                    f"The exception '{exception}' does not correspond"
                    " to a filename and will be ignored."))

        apply_cmd_to_afq_derivs(
            self.kwargs["output_dir"],
            self.export("base_fname"),
            cmd=cmd,
            exception_file_names=exception_file_names,
            suffix=suffix,
            dependent_on=dependent_on
        )

        # do not assume previous calculations are still valid
        # after file operations
        self.make_workflow()

    clobber = cmd_outputs  # alias for default of cmd_outputs