Source code for nireports.interfaces.dmri

# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2023 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
#     https://www.nipreps.org/community/licensing/
#
"""Diffusion MRI -specific visualization."""

import nibabel as nb
import numpy as np
from nipype.interfaces.base import (
    BaseInterfaceInputSpec,
    File,
    SimpleInterface,
    TraitedSpec,
    isdefined,
    traits,
)
from nipype.utils.filemanip import fname_presuffix

from nireports.reportlets.modality.dwi import plot_heatmap


class _DWIHeatmapInputSpec(BaseInterfaceInputSpec):
    in_file = File(exists=True, mandatory=True, desc="a DWI file")
    scalarmap = File(exists=True, mandatory=True, desc="a DWI-derived scalar map")
    mask_file = File(exists=True, mandatory=True, desc="a 3D mask")
    b_indices = traits.Dict(
        traits.Int, traits.List(traits.Int), mandatory=True, desc="index of b values"
    )
    threshold = traits.Float(desc="a DWI intensity threshold")
    subsample = traits.Int(desc="a number of samples if subsampling the data")
    sigma = traits.Float(desc="standard deviation of the noise (for SNR conversions)")
    colormap = traits.Str("YlGn", usedefault=True, desc="colormap of the heatmap")
    scalarmap_label = traits.Str(desc="set the label designating the scalar map")
    bins = traits.Tuple((150, 11), traits.Int, traits.Int, usedefault=True, desc="bins")


class _DWIHeatmapOutputSpec(TraitedSpec):
    out_file = File(exists=True, desc="written file path")


[docs] class DWIHeatmap(SimpleInterface): """Prepare an dMRI summary plot for the report.""" input_spec = _DWIHeatmapInputSpec output_spec = _DWIHeatmapOutputSpec def _run_interface(self, runtime): gtable = self.inputs.b_indices threshold = None if not isdefined(self.inputs.threshold) else self.inputs.threshold subsample = None if not isdefined(self.inputs.subsample) else self.inputs.subsample sigma = None if not isdefined(self.inputs.sigma) else self.inputs.sigma scalarmap_label = ( None if not isdefined(self.inputs.scalarmap_label) else self.inputs.scalarmap_label ) out_figure = plot_heatmap( nb.load(self.inputs.in_file).get_fdata(dtype="float32"), list(gtable.values()), list(gtable.keys()), np.abs(np.asanyarray(nb.load(self.inputs.mask_file).dataobj) - 1) < 1e-3, nb.load(self.inputs.scalarmap).get_fdata(dtype="float32"), scalar_label=scalarmap_label, imax=threshold, sub_size=subsample, sigma=sigma, cmap=self.inputs.colormap, bins=self.inputs.bins, ) self._results["out_file"] = fname_presuffix( self.inputs.in_file, newpath=runtime.cwd, suffix="heatmap.svg", use_ext=False, ) out_figure.savefig(self._results["out_file"], format="svg", dpi=300) out_figure.clf() out_figure = None return runtime