import nibabel as nib
import os.path as op
from time import time
import logging
from tqdm import tqdm
import numpy as np
import tempfile
import math
from PIL import Image, ImageDraw, ImageFont
from AFQ.definitions.mapping import SlrMap
from AFQ.api.utils import (
    check_attribute, AFQclass_doc,
    export_all_helper, valid_exports_string)
from AFQ.tasks.data import get_data_plan
from AFQ.tasks.mapping import get_mapping_plan
from AFQ.tasks.tractography import get_tractography_plan
from AFQ.tasks.segmentation import get_segmentation_plan
from AFQ.tasks.viz import get_viz_plan
from AFQ.tasks.utils import get_base_fname
from AFQ.utils.path import apply_cmd_to_afq_derivs
from AFQ.viz.utils import BEST_BUNDLE_ORIENTATIONS, trim, get_eye

__all__ = ["ParticipantAFQ"]


class ParticipantAFQ(object):
    f"""{AFQclass_doc}"""

    def __init__(self,
                 dwi_data_file,
                 bval_file, bvec_file,
                 output_dir,
                 **kwargs):
        """
        Initialize a ParticipantAFQ object.
        Parameters
        ----------
        dwi_data_file : str
            Path to DWI data file.
        bval_file : str
            Path to bval file.
        bvec_file : str
            Path to bvec file.
        output_dir : str
            Path to output directory.
        kwargs : additional optional parameters
            You can set additional parameters for any step
            of the process. See :ref:`usage/kwargs` for more details.
        Examples
        --------
        api.ParticipantAFQ(
            dwi_data_file, bval_file, bvec_file, output_dir,
            csd_sh_order=4)
        api.ParticipantAFQ(
            dwi_data_file, bval_file, bvec_file, output_dir,
            reg_template_spec="mni_t2", reg_subject_spec="b0")
        Notes
        -----
        In tracking_params, parameters with the suffix "mask" which are
        also images from AFQ.definitions.image will be handled
        automatically by the API.
        """
        if not isinstance(output_dir, str):
            raise TypeError(
                "output_dir must be a str")
        if not isinstance(dwi_data_file, str):
            raise TypeError(
                "dwi_data_file must be a str")
        if not isinstance(bval_file, str):
            raise TypeError(
                "bval_file must be a str")
        if not isinstance(bvec_file, str):
            raise TypeError(
                "bvec_file must be a str")
        if not op.exists(output_dir):
            raise ValueError(
                f"output_dir does not exist: {output_dir}")
        if "tractography_params" in kwargs:
            raise ValueError((
                "unrecognized parameter tractography_params, "
                "did you mean tracking_params ?"))
        self.logger = logging.getLogger('AFQ')
        self.kwargs = dict(
            dwi_data_file=dwi_data_file,
            bval_file=bval_file,
            bvec_file=bvec_file,
            output_dir=output_dir,
            base_fname=get_base_fname(output_dir, dwi_data_file),
            **kwargs)
        self.make_workflow()

    def make_workflow(self):
        # construct pimms plans
        if "mapping_definition" in self.kwargs and isinstance(
                self.kwargs["mapping_definition"], SlrMap):
            plans = {  # if using SLR map, do tractography first
                "data": get_data_plan(self.kwargs),
                "tractography": get_tractography_plan(
                    self.kwargs
                ),
                "mapping": get_mapping_plan(
                    self.kwargs,
                    use_sls=True
                ),
                "segmentation": get_segmentation_plan(self.kwargs),
                "viz": get_viz_plan(self.kwargs)}
        else:
            plans = {  # Otherwise, do mapping first
                "data": get_data_plan(self.kwargs),
                "mapping": get_mapping_plan(self.kwargs),
                "tractography": get_tractography_plan(
                    self.kwargs
                ),
                "segmentation": get_segmentation_plan(self.kwargs),
                "viz": get_viz_plan(self.kwargs)}
        # chain together a complete plan from individual plans
        previous_data = {}
        for name, plan in plans.items():
            previous_data[f"{name}_imap"] = plan(
                **self.kwargs,
                **previous_data)
        self.wf_dict = previous_data[f"{name}_imap"]

    def export(self, attr_name="help"):
        """
        Export a specific output. To print a list of available outputs,
        call export without arguments.
        Parameters
        ----------
        attr_name : str
            Name of the output to export. Default: "help"
        Returns
        -------
        output : any
            The specific output, or None if called without arguments.
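
        Examples
        --------
        A minimal sketch, assuming myafq is an initialized ParticipantAFQ
        and "dti_fa" is one of the available output names::

            myafq.export()  # print the list of available outputs
            fa_fname = myafq.export("dti_fa")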
        """
        section = check_attribute(attr_name)
        if section is None:
            return self.wf_dict[attr_name]
        return self.wf_dict[section][attr_name]

    def export_up_to(self, attr_name="help"):
        f"""
        Export all derivatives necessary for a specific output.
        To print a list of available outputs,
        call export_up_to without arguments.
        {valid_exports_string}
        Parameters
        ----------
        attr_name : str
            Name of the output to export up to. Default: "help"
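
        Examples
        --------
        A minimal sketch, assuming myafq is an initialized ParticipantAFQ
        and "streamlines" is one of the available output names::

            # compute everything "streamlines" depends on,
            # without computing "streamlines" itself
            myafq.export_up_to("streamlines")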
        """
        section = check_attribute(attr_name)
        wf_dict = self.wf_dict
        if section is not None:
            wf_dict = wf_dict[section]
        for dependent in wf_dict.plan.dependencies[attr_name]:
            self.export(dependent)

    def export_all(self, viz=True, xforms=True,
                   indiv=True):
        f""" Exports all the possible outputs
        {valid_exports_string}
        Parameters
        ----------
        viz : bool
            Whether to output visualizations. This includes tract profile
            plots, a figure containing all bundles, and, if using the AFQ
            segmentation algorithm, individual bundle figures.
            Default: True
        xforms : bool
            Whether to output the reg_template image in subject space and,
            if the mapping used makes it possible, the b0 image in
            template space.
            Default: True
        indiv : bool
            Whether to output individual bundles in their own files, in
            addition to the one file containing all bundles. If using
            the AFQ segmentation algorithm, individual ROIs are also
            output.
            Default: True
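
        Examples
        --------
        A minimal sketch, assuming myafq is an initialized ParticipantAFQ::

            # generate all derivatives, but skip the visualizations
            myafq.export_all(viz=False)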
        """
        start_time = time()
        seg_algo = self.export("segmentation_params").get("seg_algo", "AFQ")
        export_all_helper(self, seg_algo, xforms, indiv, viz)
        self.logger.info(
            f"Time taken for export all: {time() - start_time}") 
[docs]    def participant_montage(self, images_per_row=2):
        """
        Generate montage of all bundles for a given subject.
        Parameters
        ----------
        images_per_row : int
            Number of bundle images per row in output file.
            Default: 2
        Returns
        -------
        list of str
            Filenames of the saved montage images.
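
        Examples
        --------
        A minimal sketch, assuming myafq is an initialized ParticipantAFQ
        whose bundles have already been recognized::

            montage_fnames = myafq.participant_montage(images_per_row=3)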
        """
        tdir = tempfile.gettempdir()
        all_fnames = []
        bundle_dict = self.export("bundle_dict")
        self.logger.info("Generating Montage...")
        viz_backend = self.export("viz_backend")
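        # export("best_scalar") returns the name of the preferred scalar;
        # exporting that name yields the scalar data used for shading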
        best_scalar = self.export(self.export("best_scalar"))
        size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row))
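        # render each bundle to a temporary snapshot, oriented according
        # to BEST_BUNDLE_ORIENTATIONS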
        for ii, bundle_name in enumerate(tqdm(bundle_dict)):
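            # flip display axes wherever the DWI affine has a negative
            # diagonal entry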
            flip_axes = [False, False, False]
            for i in range(3):
                flip_axes[i] = (self.export("dwi_affine")[i, i] < 0)
            figure = viz_backend.visualize_volume(
                best_scalar,
                flip_axes=flip_axes,
                interact=False,
                inline=False)
            figure = viz_backend.visualize_bundles(
                self.export("bundles"),
                shade_by_volume=best_scalar,
                color_by_direction=True,
                flip_axes=flip_axes,
                bundle=bundle_name,
                figure=figure,
                interact=False,
                inline=False)
            view, direc = BEST_BUNDLE_ORIENTATIONS.get(
                bundle_name, ("Axial", "Top"))
            eye = get_eye(view, direc)
            this_fname = tdir + f"/t{ii}.png"
            if "plotly" in viz_backend.backend:
                figure.update_layout(
                    scene_camera=dict(
                        projection=dict(type="orthographic"),
                        up={"x": 0, "y": 0, "z": 1},
                        eye=eye,
                        center=dict(x=0, y=0, z=0)),
                    showlegend=False)
                figure.write_image(this_fname, scale=4)
                # temporary fix for memory leak
                import plotly.io as pio
                pio.kaleido.scope._shutdown_kaleido()
            else:
                from dipy.viz import window
                direc = np.fromiter(eye.values(), dtype=int)
                data_shape = np.asarray(
                    nib.load(self.export("b0")).get_fdata().shape)
                figure.set_camera(
                    position=direc * data_shape,
                    focal_point=data_shape // 2,
                    view_up=(0, 0, 1))
                figure.zoom(0.5)
                window.snapshot(figure, fname=this_fname, size=(600, 600))
        def _save_file(curr_img):
            save_path = op.abspath(op.join(
                self.kwargs["output_dir"],
                "bundle_montage.png"))
            curr_img.save(save_path)
            all_fnames.append(save_path)
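
        # trim each snapshot and add the bundle name as a text label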
        this_img_trimmed = {}
        max_height = 0
        max_width = 0
        for ii, bundle_name in enumerate(bundle_dict):
            this_img = Image.open(tdir + f"/t{ii}.png")
            try:
                this_img_trimmed[ii] = trim(this_img)
            except IndexError:  # this_img is a picture of nothing
                this_img_trimmed[ii] = this_img
            text_sz = 70
            width, height = this_img_trimmed[ii].size
            height = height + text_sz
            result = Image.new(
                this_img_trimmed[ii].mode, (width, height),
                color=(255, 255, 255))
            result.paste(this_img_trimmed[ii], (0, text_sz))
            this_img_trimmed[ii] = result
            draw = ImageDraw.Draw(this_img_trimmed[ii])
            draw.text(
                (0, 0), bundle_name, (0, 0, 0),
                font=ImageFont.truetype(
                    "Arial", text_sz))
            if this_img_trimmed[ii].size[0] > max_width:
                max_width = this_img_trimmed[ii].size[0]
            if this_img_trimmed[ii].size[1] > max_height:
                max_height = this_img_trimmed[ii].size[1]
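        # paste the labeled snapshots onto a single white canvas,
        # arranged in a grid of images_per_row columns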
        curr_img = Image.new(
            'RGB',
            (max_width * size[0], max_height * size[1]),
            color="white")
        for ii in range(len(bundle_dict)):
            x_pos = ii % size[0]
            _ii = ii // size[0]
            y_pos = _ii % size[1]
            _ii = _ii // size[1]
            this_img = this_img_trimmed[ii].resize((max_width, max_height))
            curr_img.paste(
                this_img,
                (x_pos * max_width, y_pos * max_height))
        _save_file(curr_img)
        return all_fnames

    def cmd_outputs(self, cmd="rm", dependent_on=None, exceptions=[],
                    suffix=""):
        """
        Perform some command on some or all outputs of pyAFQ.
        This is useful if you change a parameter and need
        to recalculate the derivatives that depend on it.
        Some examples: cp, mv, rm.
        The -r flag will be added automatically when necessary.
        Parameters
        ----------
        cmd : str
            Command to run on outputs. Default: 'rm'
        dependent_on : str or None
            Which derivatives to perform the command on.
            If None, perform on all.
            If "track", perform on all derivatives that depend on the
            tractography.
            If "recog", perform on all derivatives that depend on the
            bundle recognition.
            If "prof", perform on all derivatives that depend on the
            bundle profiling.
            Default: None
        exceptions : list of str
            Names of outputs that the command should not be applied to.
            Default: []
        suffix : str
            Parts of the command that come after the filename.
            Default: ""
        """
        exception_file_names = []
        for exception in exceptions:
            file_name = self.export(exception)
            if isinstance(file_name, str):
                exception_file_names.append(file_name)
            else:
                self.logger.warning((
                    f"The exception '{exception}' does not correspond"
                    " to a filename and will be ignored."))
        apply_cmd_to_afq_derivs(
            self.kwargs["output_dir"],
            self.export("base_fname"),
            cmd=cmd,
            exception_file_names=exception_file_names,
            suffix=suffix,
            dependent_on=dependent_on
        )
        # do not assume previous calculations are still valid
        # after file operations
        self.make_workflow()

    clobber = cmd_outputs  # alias for default of cmd_outputs