import functools
import inspect
import logging
import os.path as op
from time import time
import nibabel as nib
from dipy.io.streamline import save_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram
from AFQ.data.s3bids import write_json
import numpy as np
from AFQ.tasks.utils import get_fname
from AFQ.utils.path import drop_extension
# Public API of this module: the decorators below.
# These should only be used together with pimms.calc.
__all__ = [
    "as_file",
    "as_fit_deriv",
    "as_img",
]
logger = logging.getLogger('AFQ.api')
def get_args_and_kwargs(func):
    """Classify the positional parameters of ``func`` as args or kwargs.

    The positional parameter names are read from ``func.__code__`` (so for a
    function produced by ``has_args`` they are the generated wrapper's own
    names). Whether a name is a "kwarg" is decided from
    ``inspect.signature(func)``. By default that follows ``__wrapped__``
    (as set by ``functools.wraps``), so it may describe the original
    function rather than the wrapper.

    Parameters
    ----------
    func : function
        Function to inspect. It must have a ``__code__`` attribute.

    Returns
    -------
    param_list : tuple of str
        Positional parameter names, in declaration order.
    is_param_kwarg : dict
        Maps each name in ``param_list`` to True if the signature gives it a
        default value. It maps to False otherwise, including when the name
        does not appear in the signature at all.
    param_dict : mapping
        ``inspect.signature(func).parameters``.
    """
    param_dict = inspect.signature(func).parameters
    code = func.__code__
    param_list = code.co_varnames[:code.co_argcount]

    is_param_kwarg = {}
    for name in param_list:
        param = param_dict.get(name)
        # A name missing from the signature (e.g. an argument added by
        # has_args) is treated as a plain positional argument.
        is_param_kwarg[name] = (
            param is not None and param.default is not param.empty)
    return param_list, is_param_kwarg, param_dict
def has_args(og_func, needed_args):
    """Give a ``*args, **kwargs`` wrapper an explicit, inspectable signature.

    pimms needs to see the original parameter names after wrapping, but a
    wrapper written as ``def w(*args, **kwargs)`` hides them. The decorator
    returned here builds (via ``exec``) a new function whose parameters are,
    in order:

    1. the positional parameters of ``og_func`` that have no default;
    2. every name in ``needed_args`` not already in group 1 (these are the
       arguments the decorator itself requires; recover them inside the
       wrapper with ``extract_added_args``);
    3. the positional parameters of ``og_func`` that have a default, with
       the very same default objects.

    The generated function forwards groups 1 and 2 positionally and group 3
    by keyword to the decorated ``*args, **kwargs`` wrapper.

    Parameters
    ----------
    og_func : function
        The original function whose parameter names should be exposed.
    needed_args : list of str
        Extra argument names the decorator requires (e.g. "base_fname").

    Returns
    -------
    callable
        A decorator that takes the ``*args, **kwargs`` wrapper and returns
        the generated function.
    """
    def _has_args(func):
        param_list, is_param_kwarg, param_dict = get_args_and_kwargs(og_func)
        positional = [name for name in param_list if not is_param_kwarg[name]]
        added = [arg for arg in needed_args if arg not in positional]

        # Objects the generated source refers to by name. Default values are
        # handed over as objects instead of being formatted into the source
        # text. Formatting breaks for defaults whose str() is not a valid
        # literal: strings with quotes or newlines, inf/nan, arbitrary objects.
        # The wrapped callable uses an unlikely name so that a parameter of
        # og_func named e.g. "func" cannot shadow it.
        scope = {"_has_args_wrapped_func": func}
        header_parts = positional + added
        call_parts = positional + added
        for ii, name in enumerate(param_list):
            if not is_param_kwarg[name]:
                continue
            default_ref = f"_has_args_default_{ii}"
            scope[default_ref] = param_dict[name].default
            header_parts.append(f"{name}={default_ref}")
            call_parts.append(f"{name}={name}")

        # join() also yields valid source when there are no parameters.
        wrapper_has_args = (
            f"def wrapper_has_args_func({', '.join(header_parts)}):\n"
            f"    return _has_args_wrapped_func({', '.join(call_parts)})\n")
        exec(wrapper_has_args, scope)
        return scope["wrapper_has_args_func"]
    return _has_args
def extract_added_args(func, names, args, includes=None):
    """Recover the values of arguments added to ``func`` by ``has_args``.

    ``args`` is the positional tuple received by the ``*args`` wrapper that
    ``has_args`` generated for ``func``. Each entry of ``names`` is resolved
    as follows:

    - if ``includes`` is given and ``includes[jj]`` is falsy, the value is
      None and no position is consumed;
    - if the name is a positional, default-less parameter of ``func``, the
      value is taken from that parameter's position in ``args``;
    - otherwise the value is taken from the next unused position after
      ``func``'s own positional, default-less parameters.

    Parameters
    ----------
    func : function
        The original (undecorated) function.
    names : sequence of str
        Names of the arguments to recover.
    args : tuple
        Positional arguments received by the generated wrapper.
    includes : sequence of bool, optional
        Per-name flags; when given, names flagged False yield None.

    Returns
    -------
    tuple
        ``(n, *values)``, where ``n`` is the number of ``func``'s positional
        default-less parameters (so ``args[:n]`` are ``func``'s own args) and
        ``values`` are the recovered values in the order of ``names``.
    """
    param_list, is_param_kwarg, _ = get_args_and_kwargs(func)
    og_positional = [p for p in param_list if not is_param_kwarg[p]]
    n_og = len(og_positional)
    slot_of = {p: idx for idx, p in enumerate(og_positional)}

    extracted = []
    next_slot = n_og  # added args sit right after func's own positional args
    for jj, name in enumerate(names):
        if includes is not None and not includes[jj]:
            extracted.append(None)
        elif name in slot_of:
            extracted.append(args[slot_of[name]])
        else:
            extracted.append(args[next_slot])
            next_slot += 1
    return (n_og, *extracted)
def as_file(suffix, include_track=False, include_seg=False):
    """
    Return img and meta as saved file path, with json,
    and only run if not already found.

    The decorated function must return ``(output, meta)``. ``output`` is
    saved to ``get_fname(base_fname, suffix, ...)``:

    - with ``nib.save`` if it is a ``nib.Nifti1Image``;
    - with ``save_tractogram`` (``bbox_valid_check=False``) if it is a
      ``StatefulTractogram``;
    - with ``output.to_csv`` otherwise.

    ``meta`` is written with ``write_json`` to the path built by ``get_fname``
    from ``f"{drop_extension(suffix)}.json"``. If the output file already
    exists, the decorated function is not called and the existing path is
    returned.

    Parameters
    ----------
    suffix : str
        Suffix passed to ``get_fname`` to build the output file name.
    include_track : bool, optional
        If True, the wrapper also requires ``tracking_params`` and passes it
        to ``get_fname``. Otherwise None is passed. Default: False.
    include_seg : bool, optional
        If True, the wrapper also requires ``segmentation_params`` and passes
        it to ``get_fname``. Otherwise None is passed. Default: False.

    Returns
    -------
    callable
        Decorator. The wrapped function exposes ``base_fname`` (and the
        optional params above) as parameters if the original function does
        not already take them, and returns the output file path.
    """
    def _as_file(func):
        needed_args = ["base_fname"]
        if include_track:
            needed_args.append("tracking_params")
        if include_seg:
            needed_args.append("segmentation_params")

        @functools.wraps(func)
        @has_args(func, needed_args)
        def wrapper_as_file(*args, **kwargs):
            (og_arg_count, base_fname,
             tracking_params, segmentation_params) = extract_added_args(
                func,
                ["base_fname", "tracking_params", "segmentation_params"],
                args,
                includes=[True, include_track, include_seg])
            this_file = get_fname(
                base_fname, suffix,
                tracking_params=tracking_params,
                segmentation_params=segmentation_params)
            # Skip the (possibly expensive) computation when the output has
            # already been written by a previous run.
            if not op.exists(this_file):
                img_trk_or_csv, meta = func(*args[:og_arg_count], **kwargs)

                logger.info("Saving %s", this_file)
                if isinstance(img_trk_or_csv, nib.Nifti1Image):
                    nib.save(img_trk_or_csv, this_file)
                elif isinstance(img_trk_or_csv, StatefulTractogram):
                    save_tractogram(
                        img_trk_or_csv, this_file, bbox_valid_check=False)
                else:
                    # NOTE(review): assumed to be a pandas DataFrame (or
                    # anything with a to_csv method) — confirm with callers.
                    img_trk_or_csv.to_csv(this_file)

                meta_fname = get_fname(
                    base_fname, f"{drop_extension(suffix)}.json",
                    tracking_params=tracking_params,
                    segmentation_params=segmentation_params)
                write_json(meta_fname, meta)
            return this_file
        return wrapper_as_file
    return _as_file
def as_fit_deriv(tf_name):
    """
    Return data as nibabel image, meta with params information.

    The decorated function must return array data. The wrapper additionally
    requires ``dwi_affine`` and a ``<tf_name lowercased>_params`` argument.
    It returns ``(nib.Nifti1Image(data, dwi_affine), meta)``, where ``meta``
    is ``{"<tf_name>ParamsFile": <value of the params argument>}``.

    Parameters
    ----------
    tf_name : str
        Name used to build both the added parameter name
        (``f"{tf_name.lower()}_params"``) and the metadata key
        (``f"{tf_name}ParamsFile"``).

    Returns
    -------
    callable
        Decorator producing the wrapped function described above.
    """
    def _as_fit_deriv(func):
        needed_args = ["dwi_affine", f"{tf_name.lower()}_params"]

        @functools.wraps(func)
        @has_args(func, needed_args)
        def wrapper_as_fit_deriv(*args, **kwargs):
            og_arg_count, dwi_affine, params = extract_added_args(
                func, needed_args, args)
            img = nib.Nifti1Image(
                func(*args[:og_arg_count], **kwargs), dwi_affine)
            return img, {f"{tf_name}ParamsFile": params}
        return wrapper_as_fit_deriv
    return _as_fit_deriv
def as_img(func):
    """
    Return data, meta as nibabel image, meta with timing.

    The decorated function must return ``(data, meta)`` where ``meta`` is a
    mutable mapping. The wrapper additionally requires ``dwi_affine``.
    It stores the elapsed wall-clock seconds of the call (measured with
    ``time.time``) in ``meta['timing']``, modifying ``meta`` in place. It then
    returns ``(nib.Nifti1Image(data.astype(np.float32), dwi_affine), meta)``.
    """
    needed_args = ["dwi_affine"]

    @functools.wraps(func)
    @has_args(func, needed_args)
    def wrapper_as_img(*args, **kwargs):
        og_arg_count, affine = extract_added_args(
            func, needed_args, args)
        start_time = time()
        data, meta = func(*args[:og_arg_count], **kwargs)
        meta['timing'] = time() - start_time
        img = nib.Nifti1Image(data.astype(np.float32), affine)
        return img, meta
    return wrapper_as_img