Source code for spinalcordtoolbox.reports.slice

# -*- coding: utf-8 -*-

# TODO: Replace slice by spinalcordtoolbox.image.Slicer

import abc
import logging
import math

import numpy as np
from scipy import ndimage
from nibabel.nifti1 import Nifti1Image

from spinalcordtoolbox.image import Image
from spinalcordtoolbox.resampling import resample_nib
from spinalcordtoolbox.centerline.core import ParamCenterline, get_centerline

logger = logging.getLogger(__name__)


class Slice(metaclass=abc.ABCMeta):
    """Abstract class representing slicing applied to >=1 volumes for the purpose of generating ROI slices.

    Typically, the first volumes are images, while the last volume is a segmentation, which is used as overlay on the
    image, and/or to retrieve the center of mass to center the image on each QC mosaic square.
    For convenience, the volumes are all brought in the SAL reference frame.

    Functions with the suffix `_slice` get a slice cut in the desired axis at the "i" position of the data of the 3D
    image, while functions with the suffix `_dim` get the size of the desired dimension of the 3D image.

    Subclasses must implement `get_name`, `get_slice`, `get_dim` and `get_aspect` (enforced through ABCMeta).
    `mosaic` additionally calls `self.get_center()` and `single` calls `self.get_center_spit()`; neither is defined
    on this base class, so the concrete subclass has to provide them.

    IMPORTANT: Convention for orientation is "SAL"
    """

    def __init__(self, images, p_resample=0.6):
        """
        :param images: list of 3D volumes (Image objects) to be separated into slices. The last volume is treated as
            a segmentation when resampling.
        :param p_resample: in-plane resampling resolution in mm. If falsy, the volumes are only reoriented to SAL.
        """
        logger.info('Resample images to %sx%s mm', p_resample, p_resample)
        self._images = list()
        image_ref = None  # first pass: we don't have a reference image to resample to
        for i, image in enumerate(images):
            img = image.copy()
            img.change_orientation('SAL')
            if p_resample:
                # Last volume corresponds to a segmentation, therefore use linear interpolation (then binarization);
                # otherwise it's an image: use spline interpolation.
                type_img = 'seg' if i == len(images) - 1 else 'im'
                img_r = self._resample_slicewise(img, p_resample, type_img=type_img, image_ref=image_ref)
            else:
                # `img` is already a private copy, no need to copy it again
                img_r = img
            self._images.append(img_r)
            # 2nd and next passes: we resample any image to the space of the first one
            image_ref = self._images[0]

    @staticmethod
    def axial_slice(data, i):
        """Return the slice at index `i` along axis 0 (S-I in SAL) of the 3D array `data`."""
        return data[i, :, :]

    @staticmethod
    def axial_dim(image):
        """Return the size of axis 0 (S-I in SAL) of `image`."""
        nx, ny, nz, nt, px, py, pz, pt = image.dim
        return nx

    @staticmethod
    def axial_aspect(image):
        """Return the pixel-size ratio (rows/columns) of an axial slice, i.e. py / pz."""
        nx, ny, nz, nt, px, py, pz, pt = image.dim
        return py / pz

    @staticmethod
    def sagittal_slice(data, i):
        """Return the slice at index `i` along axis 2 (R-L in SAL) of the 3D array `data`."""
        return data[:, :, int(i)]

    @staticmethod
    def sagittal_dim(image):
        """Return the size of axis 2 (R-L in SAL) of `image`."""
        nx, ny, nz, nt, px, py, pz, pt = image.dim
        return nz

    @staticmethod
    def sagittal_aspect(image):
        """Return the pixel-size ratio (rows/columns) of a sagittal slice, i.e. px / py."""
        nx, ny, nz, nt, px, py, pz, pt = image.dim
        return px / py

    @staticmethod
    def coronal_slice(data, i):
        """Return the slice at index `i` along axis 1 (A-P in SAL) of the 3D array `data`."""
        return data[:, i, :]

    @staticmethod
    def coronal_dim(image):
        """Return the size of axis 1 (A-P in SAL) of `image`."""
        nx, ny, nz, nt, px, py, pz, pt = image.dim
        return ny

    @staticmethod
    def coronal_aspect(image):
        """Return the pixel-size ratio (rows/columns) of a coronal slice, i.e. px / pz."""
        nx, ny, nz, nt, px, py, pz, pt = image.dim
        return px / pz

    @abc.abstractmethod
    def get_aspect(self, image):
        """Abstract method returning the display aspect ratio of the slices of `image`."""
        return

    @staticmethod
    def crop(matrix, x, y, width, height):
        """Crop a (2*width) x (2*height) window of `matrix` centered on (x, y).

        If the requested `width` or `height` is larger than half the matrix size along that axis, it is reduced to
        fit. The center is then shifted as needed so that the window lies entirely inside the matrix (on the low AND
        the high side), so the returned patch always has shape (2*width, 2*height) after that reduction.

        TODO : Move this into the Axial class

        :param matrix: Array representation of the image
        :param x: The center of the crop area along axis 0 (rows)
        :param y: The center of the crop area along axis 1 (columns)
        :param width: The half-size of the crop along axis 0
        :param height: The half-size of the crop along axis 1
        :returns: cropped matrix
        """
        n_rows, n_cols = matrix.shape[0], matrix.shape[1]
        if width * 2 > n_rows:
            width = n_rows // 2
        if height * 2 > n_cols:
            height = n_cols // 2
        # Shift the center inward so the window never runs past either edge
        x = min(max(x, width), n_rows - width)
        y = min(max(y, height), n_cols - height)
        start_row = x - width
        start_col = y - height
        return matrix[start_row:start_row + width * 2, start_col:start_col + height * 2]

    @staticmethod
    def add_slice(matrix, i, column, size, patch):
        """Adds a slice to the canvas containing all the slices

        TODO : Move this to the Axial class

        :param matrix: input/output "big canvas" (modified in place)
        :param i: slice position
        :param column: number of columns in mosaic
        :param size: half-size of a mosaic square (each square is size*2 wide)
        :param patch: patch to insert at the top-left corner of square `i`
        :return: matrix
        """
        start_row = (i // column) * size * 2
        start_col = (i % column) * size * 2
        matrix[start_row:start_row + patch.shape[0], start_col:start_col + patch.shape[1]] = patch
        return matrix

    @staticmethod
    def nan_fill(A):
        """Interpolate NaN values with neighboring values in array (in-place)

        Interpolation is linear over the flattened index. If only NaNs, the array is filled with zeros.

        :param A: float array (modified in place)
        :return: A
        """
        nans = np.isnan(A)
        if not nans.any():
            return A
        if nans.all():
            A[:] = 0
            return A
        xp = (~nans).ravel().nonzero()[0]
        fp = A[~nans]
        x = nans.ravel().nonzero()[0]
        A[nans] = np.interp(x, xp, fp)
        return A

    @abc.abstractmethod
    def get_name(self):
        """Get the class name"""
        return

    @abc.abstractmethod
    def get_slice(self, data, i):
        """Abstract method to obtain a slice of a 3d matrix

        :param data: volume
        :param i: position to slice
        :return: 2D slice
        """
        return

    @abc.abstractmethod
    def get_dim(self, image):
        """Abstract method to obtain the depth of the 3d matrix.

        :param image: input Image
        :returns: number of slices along the slicing axis
        """
        return

    def _axial_center(self, image):
        """Gets the center of mass in the axial plane

        Slices with no signal yield NaN centers, which are filled by interpolation from neighboring slices.

        :param image: input Image (SAL orientation)
        :returns: centers of mass along axes 1 and 2 for each axis-0 slice (tuple of numpy.ndarray of float)
        """
        logger.info('Compute center of mass at each slice')
        data = np.array(image.data)  # we cast np.array to overcome problem if inputing nii format
        nz = image.dim[0]  # SAL orientation
        centers_x = np.zeros(nz)
        centers_y = np.zeros(nz)
        for i in range(nz):
            # `ndimage.measurements` is a deprecated alias namespace; use the top-level function
            centers_x[i], centers_y[i] = ndimage.center_of_mass(data[i, :, :])
        try:
            Slice.nan_fill(centers_x)
            Slice.nan_fill(centers_y)
        except ValueError as err:
            logger.error("Axial center of the spinal cord is not found: %s", err)
            raise
        return centers_x, centers_y

    def mosaic(self, nb_column=0, size=15, return_center=False):
        """Obtain matrices of the mosaics

        Calculates how many squares will fit in a row based on the column and the size.
        Multiply by 2 because the sides are of size*2. Central point is size +/-.

        :param nb_column: number of mosaic columns (0: as many squares as fit in 600 pixels)
        :param size: half-size of each mosaic square
        :param return_center: if True, also return the centers of the mosaic squares
        :return: list of numpy.ndarray containing the mosaic of each volume
        :return: list of tuples, each tuple representing the center of each square of the mosaic. Only with param
            return_center is True
        """
        # Calculate number of columns to display on the report
        dim = self.get_dim(self._images[0])  # dim represents the 3rd dimension of the 3D matrix
        if nb_column == 0:
            nb_column = 600 // (size * 2)
        # Round up so that all `dim` squares fit, without an empty trailing row when dim is a multiple of nb_column
        nb_row = math.ceil(dim / nb_column)

        # Compute the matrix size of the final mosaic image
        matrix_sz = (int(size * 2 * nb_row), int(size * 2 * nb_column))

        centers_mosaic = [(icol * size * 2 + size, irow * size * 2 + size)
                          for irow in range(nb_row) for icol in range(nb_column)]

        # Get center of mass for each slice of the image. If the input is the cord segmentation, these coordinates
        # are used to center the image on each panel of the mosaic.
        centers_x, centers_y = self.get_center()

        matrices = list()
        for image in self._images:
            matrix = np.zeros(matrix_sz)
            for i in range(dim):
                x = int(centers_x[i])
                y = int(centers_y[i])
                # crop slice around center of mass and add slice to the matrix layout
                # TODO: resample there after cropping based on physical dimensions
                patch = self.crop(self.get_slice(image.data, i), x, y, size, size)
                self.add_slice(matrix, i, nb_column, size, patch)
            matrices.append(matrix)
        if return_center is True:
            return matrices, centers_mosaic
        return matrices

    def single(self):
        """Obtain the matrices of the single slices. Flatten

        Each row (S-I position) of the output is taken from the slice index returned by `get_center_spit` for that
        row, so that the cord is centered on every row.

        :returns: list of numpy.ndarray, one 2D matrix per input volume
        """
        assert len(set([x.data.shape for x in self._images])) == 1, "Volumes don't have the same size"

        matrices = list()
        # Retrieve the L-R center of the slice for each row (i.e. in the S-I direction).
        index = self.get_center_spit()
        # Loop across images and generate matrices for the image and the overlay
        for image in self._images:
            # Initialize matrix with zeros. This matrix corresponds to the 2d slice shown on the QC report.
            matrix = np.zeros(image.dim[0:2])
            for row, idx in enumerate(index):
                # For each row, pick the slice that centers the cord in the R-L direction
                matrix[row] = self.get_slice(image.data, int(np.round(idx)))[row]
            matrices.append(matrix)
        return matrices

    def aspect(self):
        """Return the display aspect ratio of each stored volume."""
        return [self.get_aspect(x) for x in self._images]

    def _resample_slicewise(self, image, p_resample, type_img, image_ref=None):
        """
        Resample at a fixed resolution to make sure the cord always appears with similar scale, regardless of the
        native resolution of the image. Assumes SAL orientation.

        :param image: Image() to resample
        :param p_resample: float: Resampling resolution in mm
        :param type_img: {'im', 'seg'}: If im, interpolate using spline. If seg, interpolate using linear then
            binarize.
        :param image_ref: Destination Image() to resample image to.
        :return: Image() resampled, in the orientation of the input image
        """
        dict_interp = {'im': 'spline', 'seg': 'linear'}
        # Create nibabel object
        nii = Nifti1Image(image.data, image.hdr.get_best_affine())
        # If no reference image is provided, resample to specified resolution
        if image_ref is None:
            # Resample to px x p_resample x p_resample mm (orientation is SAL by convention in QC module)
            nii_r = resample_nib(nii, new_size=[image.dim[4], p_resample, p_resample], new_size_type='mm',
                                 interpolation=dict_interp[type_img])
        # Otherwise, resampling to the space of the reference image
        else:
            # Create nibabel object for reference image
            nii_ref = Nifti1Image(image_ref.data, image_ref.hdr.get_best_affine())
            nii_r = resample_nib(nii, image_dest=nii_ref, interpolation=dict_interp[type_img])
        # `get_data()` was removed in nibabel 5; np.asanyarray(dataobj) is its behavior-preserving replacement
        data_r = np.asanyarray(nii_r.dataobj)
        # If resampled image is a segmentation, binarize using threshold at 0.5
        if type_img == 'seg':
            img_r_data = (data_r > 0.5) * 1
        else:
            img_r_data = data_r
        # Create Image objects
        image_r = Image(img_r_data, hdr=nii_r.header, dim=nii_r.header.get_data_shape()). \
            change_orientation(image.orientation)
        return image_r
class Axial(Slice):
    """The axial representation of a slice.

    Slices are cut along the first axis of the SAL-oriented volumes (one slice per S-I position).
    """

    def get_name(self):
        """Return the name of this view."""
        return Axial.__name__

    def get_aspect(self, image):
        """Return the display aspect ratio of the axial slices of `image`."""
        return Slice.axial_aspect(image)

    def get_slice(self, data, i):
        """Return the axial slice at position `i` of the 3D array `data`."""
        return self.axial_slice(data, i)

    def get_dim(self, image):
        """Return how many axial slices `image` contains."""
        return self.axial_dim(image)

    def get_center(self, img_idx=-1):
        """Return the per-slice center of mass of one stored volume.

        :param img_idx: index in self._images of the volume to use. The default (-1) picks the last volume, which
            by convention is the segmentation.
        :return: tuple (centers along axis 1, centers along axis 2), one value per axial slice
        """
        return self._axial_center(self._images[img_idx])
class Sagittal(Slice):
    """The sagittal representation of a slice.

    Slices are cut along the last axis of the SAL-oriented volumes (one slice per R-L position); each slice has
    shape (S-I, A-P).
    """

    def get_name(self):
        """Return the name of this view."""
        return Sagittal.__name__

    def get_aspect(self, image):
        """Return the display aspect ratio of the sagittal slices of `image`."""
        return Slice.sagittal_aspect(image)

    def get_slice(self, data, i):
        """Return the sagittal slice at position `i` of the 3D array `data`."""
        return self.sagittal_slice(data, i)

    def get_dim(self, image):
        """Return how many sagittal slices `image` contains."""
        return self.sagittal_dim(image)

    def get_center_spit(self, img_idx=-1):
        """
        Retrieve index along in the R-L direction for each S-I slice in order to center the spinal cord in the
        medial plane, around the labels or segmentation.

        By default, it looks at the latest image in the input list of images, assuming the latest is the labels or
        segmentation.

        If only one label is found, the cord will be centered at that label.

        :raises ValueError: if the label/segmentation image is empty
        :return: index: [int] * n_SI
        """
        image = self._images[img_idx].copy()
        assert image.orientation == 'SAL'
        # Nonzero voxel coordinates, computed once (SAL: column 2 is the R-L index)
        coords = np.argwhere(image.data)
        # If mask is empty, raise error
        if coords.shape[0] == 0:
            raise ValueError("Label/segmentation image is empty. Can't retrieve RL slice indices.")
        # If mask only has one label (e.g., in sct_detect_pmj), return the R-L index (repeated n_SI times)
        if coords.shape[0] == 1:
            return [coords[0][2]] * image.data.shape[0]  # SAL orientation, so shape[0] -> SI axis
        # Otherwise, find the center of mass of each label (per axial plane) and extrapolate linearly
        image.change_orientation('RPI')  # need to do that because get_centerline operates in RPI orientation
        # Get coordinate of centerline. smooth=0 so that the centerline passes through the labels.
        # NOTE(review): minmax=False is passed here; presumably this keeps the fitted centerline from being cropped
        # to the labels' S-I extent, so that `single()` gets one R-L index per row -- confirm against
        # ParamCenterline.
        data_ctl_RPI, _, _, _ = get_centerline(
            image, param=ParamCenterline(algo_fitting='linear', smooth=0, minmax=False))
        data_ctl_RPI.change_orientation('SAL')
        index_RL = np.argwhere(data_ctl_RPI.data)
        return [coord[2] for coord in index_RL]

    def get_center(self, img_idx=-1):
        """Return the crop centers used by `mosaic`: the middle of every sagittal slice.

        `get_slice` returns data[:, :, i], i.e. an array of shape (nx, ny) in SAL, and `crop` interprets its `x`
        argument as the axis-0 (row) center and `y` as the axis-1 (column) center. The row center therefore comes
        from axial_dim (nx) and the column center from coronal_dim (ny).

        :param img_idx: index in self._images of the volume whose dimensions are used
        :return: tuple (row centers, column centers), one value per sagittal slice
        """
        image = self._images[img_idx]
        dim = self.get_dim(image)
        n_rows = self.axial_dim(image)  # nx: S-I size, rows of a sagittal slice
        n_cols = self.coronal_dim(image)  # ny: A-P size, columns of a sagittal slice
        return np.ones(dim) * n_rows / 2, np.ones(dim) * n_cols / 2