Source code for spinalcordtoolbox.centerline.core

#!/usr/bin/env python
# -*- coding: utf-8
# Core functions dealing with centerline extraction from 3D data.


import logging
import numpy as np

from spinalcordtoolbox.image import Image, zeros_like
from spinalcordtoolbox.centerline import curve_fitting

logger = logging.getLogger(__name__)


[docs]class ParamCenterline: """Default parameters for centerline fitting""" def __init__(self, algo_fitting='bspline', degree=5, smooth=20, contrast=None, minmax=True): """ :param algo_fitting: str: polyfit: Polynomial fitting\ bspline: b-spline fitting\ linear: linear interpolation followed by smoothing with Hanning window\ nurbs: Non-Uniform Rational Basis Splines. See [De Leener et al., MIA, 2016]\ optic: Automatic segmentation using SVM and HOG. See [Gros et al., MIA 2018].\ :param degree: Degree of polynomial function. Only for polyfit. :param smooth: Degree of smoothness. Only for bspline and linear. When using algo_fitting=linear, it corresponds :param contrast: Contrast type for algo_fitting=optic. :param minmax: Crop output centerline where the segmentation starts/end. If False, centerline will span all\ slices to the size of a Hanning window (in mm). """ self.algo_fitting = algo_fitting self.contrast = contrast self.degree = degree self.smooth = smooth self.minmax = minmax
[docs]class FitResults: """ Collection of metrics to assess fitting performance """ def __init__(self): class Data: """Raw and fitted data""" def __init__(self): self.xmean = None self.xfit = None self.ymean = None self.yfit = None self.zmean = None self.zref = None self.rmse = None # RMSE self.laplacian_max = None # Maximum of 2nd derivatives self.data = Data() # Raw and fitted data (for plotting in QC report) self.param = None # ParamCenterline()
[docs]def find_and_sort_coord(img): """ Find x,y,z coordinate of centerline and output an array which is sorted along SI direction. Removes any duplicate along the SI direction by averaging across the same ind_SI. :param img: Image(): Input image. Could be any orientation. :return: nx3 numpy array with X, Y, Z coordinates of center of mass """ # TODO: deal with nan, etc. # Get indices of non-null values arr = np.array(np.where(img.data)) # Sort indices according to SI axis dim_si = [img.orientation.find(x) for x in ['I', 'S'] if img.orientation.find(x) != -1][0] # Loop across SI axis and average coordinates within duplicate SI values arr_sorted_avg = [np.array([])] * 3 for i_si in sorted(set(arr[dim_si])): # Get indices corresponding to i_si ind_si_all = np.where(arr[dim_si] == i_si) if len(ind_si_all[0]): # loop across dimensions and average all existing coordinates (equivalent to center of mass) for i_dim in range(3): arr_sorted_avg[i_dim] = np.append(arr_sorted_avg[i_dim], arr[i_dim][ind_si_all].mean()) return np.array(arr_sorted_avg)
[docs]def get_centerline(im_seg, param=ParamCenterline(), verbose=1): """ Extract centerline from an image (using optic) or from a binary or weighted segmentation (using the center of mass). :param im_seg: Image(): Input segmentation or series of points along the centerline. :param param: ParamCenterline() class: :param verbose: int: verbose level :return: im_centerline: Image: Centerline in discrete coordinate (int) :return: arr_centerline: 3x1 array: Centerline in continuous coordinate (float) for each slice in RPI orientation. :return: arr_centerline_deriv: 3x1 array: Derivatives of x and y centerline wrt. z for each slice in RPI orient. :return: fit_results: FitResults class """ if not isinstance(im_seg, Image): raise ValueError("Expecting an image") # Open image and change to RPI orientation native_orientation = im_seg.orientation im_seg.change_orientation('RPI') px, py, pz = im_seg.dim[4:7] # Take the center of mass at each slice to avoid: https://stackoverflow.com/questions/2009379/interpolate-question x_mean, y_mean, z_mean = find_and_sort_coord(im_seg) # Crop output centerline to where the segmentation starts/end if param.minmax: z_ref = np.array(range(z_mean.min().astype(int), z_mean.max().astype(int) + 1)) else: z_ref = np.array(range(im_seg.dim[2])) index_mean = np.array([list(z_ref).index(i) for i in z_mean]) # Choose method if param.algo_fitting == 'polyfit': x_centerline_fit, x_centerline_deriv = curve_fitting.polyfit_1d(z_mean, x_mean, z_ref, deg=param.degree) y_centerline_fit, y_centerline_deriv = curve_fitting.polyfit_1d(z_mean, y_mean, z_ref, deg=param.degree) fig_title = 'Algo={}, Deg={}'.format(param.algo_fitting, param.degree) elif param.algo_fitting == 'bspline': x_centerline_fit, x_centerline_deriv = curve_fitting.bspline(z_mean, x_mean, z_ref, param.smooth, pz=pz) y_centerline_fit, y_centerline_deriv = curve_fitting.bspline(z_mean, y_mean, z_ref, param.smooth, pz=pz) fig_title = 'Algo={}, Smooth={}'.format(param.algo_fitting, param.smooth) elif param.algo_fitting == 'linear': # Simple linear interpolation x_centerline_fit, x_centerline_deriv = curve_fitting.linear(z_mean, x_mean, z_ref, param.smooth, pz=pz) y_centerline_fit, y_centerline_deriv = curve_fitting.linear(z_mean, y_mean, z_ref, param.smooth, pz=pz) fig_title = 'Algo={}, Smooth={}'.format(param.algo_fitting, param.smooth) elif param.algo_fitting == 'nurbs': from spinalcordtoolbox.centerline.nurbs import b_spline_nurbs point_number = 3000 # Interpolate such that the output centerline has the same length as z_ref x_mean_interp, _ = curve_fitting.linear(z_mean, x_mean, z_ref, 0) y_mean_interp, _ = curve_fitting.linear(z_mean, y_mean, z_ref, 0) x_centerline_fit, y_centerline_fit, z_centerline_fit, x_centerline_deriv, y_centerline_deriv, \ z_centerline_deriv, error = b_spline_nurbs(x_mean_interp, y_mean_interp, z_ref, nbControl=None, point_number=point_number, all_slices=True) # Normalize derivatives to z_deriv x_centerline_deriv = x_centerline_deriv / z_centerline_deriv y_centerline_deriv = y_centerline_deriv / z_centerline_deriv fig_title = 'Algo={}, NumberPoints={}'.format(param.algo_fitting, point_number) elif param.algo_fitting == 'optic': # This method is particular compared to the previous ones, as here we estimate the centerline based on the # image itself (not the segmentation). Hence, we can bypass the fitting procedure and centerline creation # and directly output results. from spinalcordtoolbox.centerline import optic assert param.contrast is not None im_centerline = optic.detect_centerline(im_seg, param.contrast, verbose) x_centerline_fit, y_centerline_fit, z_centerline = find_and_sort_coord(im_centerline) # Compute derivatives using polynomial fit # TODO: Fix below with reorientation of axes _, x_centerline_deriv = curve_fitting.polyfit_1d(z_centerline, x_centerline_fit, z_centerline, deg=param.degree) _, y_centerline_deriv = curve_fitting.polyfit_1d(z_centerline, y_centerline_fit, z_centerline, deg=param.degree) return \ im_centerline.change_orientation(native_orientation), \ np.array([x_centerline_fit, y_centerline_fit, z_centerline]), \ np.array([x_centerline_deriv, y_centerline_deriv, np.ones_like(z_centerline)]), \ None else: logger.error('algo_fitting "' + param.algo_fitting + '" does not exist.') raise ValueError # Create an image with the centerline im_centerline = im_seg.copy() im_centerline.data = np.zeros(im_centerline.data.shape) # Assign value=1 to centerline. Make sure to clip to avoid array overflow. # TODO: check this round and clip-- suspicious im_centerline.data[round_and_clip(x_centerline_fit, clip=[0, im_centerline.data.shape[0]]), round_and_clip(y_centerline_fit, clip=[0, im_centerline.data.shape[1]]), z_ref] = 1 # reorient centerline to native orientation im_centerline.change_orientation(native_orientation) im_seg.change_orientation(native_orientation) # TODO: Reorient centerline in native orientation. For now, we output the array in RPI. Note that it is tricky to # reorient in native orientation, because the voxel center is not in the middle, but in the top corner, so this # needs to be taken into accound during reorientation. The code below does not work properly. # # Get a permutation and inversion based on native orientation # perm, inversion = _get_permutations(im_seg.orientation, native_orientation) # # axes inversion (flip) # # ctl = np.array([x_centerline_fit[::inversion[0]], y_centerline_fit[::inversion[1]], z_ref[::inversion[2]]]) # ctl = np.array([x_centerline_fit, y_centerline_fit, z_ref]) # ctl_deriv = np.array([x_centerline_deriv[::inversion[0]], y_centerline_deriv[::inversion[1]], np.ones_like(z_ref)]) # return im_centerline, \ # np.array([ctl[perm[0]], ctl[perm[1]], ctl[perm[2]]]), \ # np.array([ctl_deriv[perm[0]], ctl_deriv[perm[1]], ctl_deriv[perm[2]]]) # Compute fitting metrics fit_results = FitResults() fit_results.rmse = np.sqrt(np.mean((x_mean - x_centerline_fit[index_mean]) ** 2) * px + np.mean((y_mean - y_centerline_fit[index_mean]) ** 2) * py) fit_results.laplacian_max = np.max([ np.absolute(np.gradient(np.array(x_centerline_deriv * px))).max(), np.absolute(np.gradient(np.array(y_centerline_deriv * py))).max()]) fit_results.data.zmean = z_mean fit_results.data.zref = z_ref fit_results.data.xmean = x_mean fit_results.data.xfit = x_centerline_fit fit_results.data.ymean = y_mean fit_results.data.yfit = y_centerline_fit fit_results.param = param # Display fig of fitted curves if verbose == 2: from datetime import datetime import matplotlib matplotlib.use('Agg') # prevent display figure import matplotlib.pyplot as plt plt.figure(figsize=(16, 10)) plt.subplot(3, 1, 1) plt.title(fig_title + '\nRMSE[mm]={:0.2f}, LaplacianMax={:0.2f}'.format(fit_results.rmse, fit_results.laplacian_max)) plt.plot(z_mean * pz, x_mean * px, 'ro') plt.plot(z_ref * pz, x_centerline_fit * px, 'k') plt.plot(z_ref * pz, x_centerline_fit * px, 'k.') plt.ylabel("X [mm]") plt.legend(['Reference', 'Fitting', 'Fitting points']) plt.subplot(3, 1, 2) plt.plot(z_mean * pz, y_mean * py, 'ro') plt.plot(z_ref * pz, y_centerline_fit * py, 'b') plt.plot(z_ref * pz, y_centerline_fit * py, 'b.') plt.xlabel("Z [mm]") plt.ylabel("Y [mm]") plt.legend(['Reference', 'Fitting', 'Fitting points']) plt.subplot(3, 1, 3) plt.plot(z_ref * pz, x_centerline_deriv * px, 'k.') plt.plot(z_ref * pz, y_centerline_deriv * py, 'b.') plt.grid(axis='y', color='grey', linestyle=':', linewidth=1) plt.axhline(color='grey', linestyle='-', linewidth=1) # plt.plot(z_ref * pz, z_centerline_deriv * pz, 'r.') plt.ylabel("dX/dZ, dY/dZ") plt.xlabel("Z [mm]") plt.legend(['X-deriv', 'Y-deriv']) plt.savefig('fig_centerline_' + datetime.now().strftime("%y%m%d-%H%M%S%f") + '_' + param.algo_fitting + '.png') plt.close() return im_centerline, \ np.array([x_centerline_fit, y_centerline_fit, z_ref]), \ np.array([x_centerline_deriv, y_centerline_deriv, np.ones_like(z_ref)]), \ fit_results
[docs]def round_and_clip(arr, clip=None): """ FIXME doc Round to closest int, convert to dtype=int and clip to min/max values allowed by the list length :param arr: :param clip: [min, max]: Clip values in arr to min and max :return: """ if clip: return np.clip(arr.round().astype(int), clip[0], clip[1]-1) else: return arr.round().astype(int)
def _call_viewer_centerline(im_data, interslice_gap=20.0): # TODO: _call_viewer_centerline should not be "internal" anymore, i.e., remove the "_" """ FIXME doc Call Qt viewer for manually selecting labels. :param im_data: :param interslice_gap: :return: Image() of labels. """ from spinalcordtoolbox.gui.base import AnatomicalParams from spinalcordtoolbox.gui.centerline import launch_centerline_dialog if not isinstance(im_data, Image): raise ValueError("Expecting an image") # Get the number of slice along the (IS) axis im_tmp = im_data.copy().change_orientation('RPI') _, _, nz, _, _, _, pz, _ = im_tmp.dim del im_tmp params = AnatomicalParams() # setting maximum number of points to a reasonable value params.num_points = np.ceil(nz * pz / interslice_gap) + 2 params.interval_in_mm = interslice_gap # Starting at the top slice minus 1 in cases where the first slice is almost zero, due to gradient # non-linearity correction: # https://forum.spinalcordmri.org/t/centerline-viewer-enhancements-sct-v5-0-1/605/4?u=jcohenadad params.starting_slice = 'top_minus_one' im_mask_viewer = zeros_like(im_data) launch_centerline_dialog(im_data, im_mask_viewer, params) return im_mask_viewer