Source code for spinalcordtoolbox.vertebrae.core

#!/usr/bin/env python
# -*- coding: utf-8
# Core functions dealing with vertebral labeling

# TODO: remove i/o as much as possible

import os
import logging

import numpy as np
import scipy.ndimage.measurements
from scipy.ndimage.filters import gaussian_filter
from scipy.signal import gaussian

from spinalcordtoolbox.image import Image, add_suffix
from spinalcordtoolbox.metadata import get_file_label
from spinalcordtoolbox.math import dilate, mutual_information

logger = logging.getLogger(__name__)


def label_vert(fname_seg, fname_label, verbose=1):
    """
    Label a segmentation from an image containing disc labels. No orientation expected.

    The label image is reoriented to RPI and every non-zero voxel is read as one disc (its z coordinate and its
    value). Discs are ordered by decreasing z and handed to label_segmentation() and label_discs(), which write
    the outputs next to fname_seg (suffixes "_labeled" and "_labeled_disc").

    :param fname_seg: file name of segmentation.
    :param fname_label: file name for a labelled segmentation that will be used to label the input segmentation
    :param verbose: forwarded to label_segmentation() and label_discs()
    :return: None
    """
    # Read every labelled voxel of the disc image, in RPI orientation
    discs = Image(fname_label).change_orientation("RPI").getNonZeroCoordinates()

    # Pair each z with its disc value and order the pairs from the highest z to the lowest.
    # '-1' to use the convention "disc labelvalue=3 ==> disc C2/C3"
    z_value_pairs = sorted(((disc.z, disc.value - 1) for disc in discs), reverse=True)
    list_disc_z = [disc_z for disc_z, _ in z_value_pairs]
    list_disc_value = [disc_value for _, disc_value in z_value_pairs]

    # Write the labeled segmentation and the single-voxel disc labels
    label_segmentation(fname_seg, list_disc_z, list_disc_value, verbose=verbose)
    label_discs(fname_seg, list_disc_z, list_disc_value, verbose=verbose)
def vertebral_detection(fname, fname_seg, contrast, param, init_disc, verbose=1, path_template='', path_output='../',
                        scale_dist=1.):
    """
    Find intervertebral discs in straightened image using template matching

    Starting from an initial disc, the function walks in the superior direction, then in the inferior direction,
    estimating the z position of each disc with compute_corr_3d() and rescaling the template inter-disc distances
    with the discs found so far. The result is written by label_segmentation() and label_discs().

    :param fname: file name of straightened spinal cord
    :param fname_seg: file name of straightened spinal cord segmentation
    :param contrast: t1, t2 or t2s (case-insensitive); selects the template image
    :param param: advanced parameters. Attributes read here: smooth_factor, size_RL, size_AP, size_IS, shift_AP,
                  shift_AP_visu
    :param init_disc: sequence where init_disc[0] is the z position and init_disc[1] the value of the initial disc
    :param verbose: 2 to save a figure of the detected discs (and per-disc figures in compute_corr_3d)
    :param path_template: folder that contains the 'template' sub-folder
    :param path_output: output path for verbose=2 pictures (forwarded to compute_corr_3d)
    :param scale_dist: float: Scaling factor to adjust average distance between two adjacent intervertebral discs
    :return: None
    """
    logger.info('Look for template...')
    logger.info('Path template: %s', path_template)

    # adjust file names if MNI-Poly-AMU template is used (by default: PAM50)
    fname_level = get_file_label(os.path.join(path_template, 'template'), id_label=7,
                                 output='filewithpath')  # label = spinal cord mask with discrete vertebral levels
    id_label_dct = {'T1': 0, 'T2': 1, 'T2S': 2}
    fname_template = get_file_label(os.path.join(path_template, 'template'), id_label=id_label_dct[contrast.upper()],
                                    output='filewithpath')  # label = *-weighted template

    # Open template and vertebral levels
    logger.info('Open template and vertebral levels...')
    data_template = Image(fname_template).data
    data_disc_template = Image(fname_level).data

    # open anatomical volume
    im_input = Image(fname)
    data = im_input.data

    # smooth data
    data = gaussian_filter(data, param.smooth_factor, output=None, mode="reflect")

    # get dimension of src
    nx, ny, nz = data.shape
    # define xc and yc (centered in the field of view)
    xc = int(np.round(nx / 2))  # direction RL
    yc = int(np.round(ny / 2))  # direction AP
    # get dimension of template
    nxt, nyt, nzt = data_template.shape
    # define xc and yc (centered in the field of view)
    xct = int(np.round(nxt / 2))  # direction RL
    yct = int(np.round(nyt / 2))  # direction AP

    # define mean distance (in voxel) between adjacent discs: [C1/C2 -> C2/C3], [C2/C3 -> C4/C5], ..., [L1/L2 -> L2/L3]
    centerline_level = data_disc_template[xct, yct, :]
    # attribute value to each disc. Starts from max level, then decrease.
    min_level = centerline_level[centerline_level.nonzero()].min()
    max_level = centerline_level[centerline_level.nonzero()].max()
    list_disc_value_template = list(range(min_level, max_level))
    # add disc above top one
    list_disc_value_template.insert(int(0), min_level - 1)
    logger.info('Disc values from template: %s', list_disc_value_template)
    # get diff to find transitions (i.e., discs)
    diff_centerline_level = np.diff(centerline_level)
    # get disc z-values
    list_disc_z_template = diff_centerline_level.nonzero()[0].tolist()
    list_disc_z_template.reverse()
    logger.info('Z-values for each disc: %s', list_disc_z_template)
    list_distance_template = (
        np.diff(list_disc_z_template) * (-1)).tolist()  # multiplies by -1 to get positive distances
    # Update distance with scaling factor
    list_distance_template = [i * scale_dist for i in list_distance_template]
    logger.info('Distances between discs (in voxel): %s', list_distance_template)

    # display init disc
    if verbose == 2:
        from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
        from matplotlib.figure import Figure
        fig_disc = Figure()
        FigureCanvas(fig_disc)
        ax_disc = fig_disc.add_subplot(111)
        # get percentile for automatic contrast adjustment
        data_display = np.mean(data[xc - param.size_RL:xc + param.size_RL, :, :], axis=0).transpose()
        percmin = np.percentile(data_display, 10)
        percmax = np.percentile(data_display, 90)
        # display image
        ax_disc.matshow(data_display, cmap='gray', clim=[percmin, percmax], origin='lower')
        ax_disc.set_title('Anatomical image')
        ax_disc.scatter(yc + param.shift_AP_visu, init_disc[0], c='yellow', s=10)
        ax_disc.text(yc + param.shift_AP_visu + 4, init_disc[0], str(init_disc[1]) + '/' + str(init_disc[1] + 1),
                     verticalalignment='center', horizontalalignment='left', color='pink', fontsize=7)

    # FIND DISCS
    # ===========================================================================
    logger.info('Detect intervertebral discs...')
    # assign initial z and disc
    current_z = init_disc[0]
    current_disc = init_disc[1]
    # create list for z and disc
    list_disc_z = []
    list_disc_value = []
    zrange = list(range(-10, 10))
    direction = 'superior'
    search_next_disc = True
    while search_next_disc:
        logger.info('Current disc: %s (z=%s). Direction: %s', current_disc, current_z, direction)
        try:
            # get z corresponding to current disc on template
            current_z_template = list_disc_z_template[current_disc]
        except IndexError:
            # in case reached the bottom (see issue #849).
            # Only the out-of-range lookup is expected here; any other error must propagate.
            logger.warning('Reached the bottom of the template. Stop searching.')
            break
        # find next disc
        # N.B. Do not search for C1/C2 disc (because poorly visible), use template distance instead
        if current_disc != 1:
            current_z = compute_corr_3d(data, data_template, x=xc, xshift=0, xsize=param.size_RL,
                                        y=yc, yshift=param.shift_AP, ysize=param.size_AP,
                                        z=current_z, zshift=0, zsize=param.size_IS,
                                        xtarget=xct, ytarget=yct, ztarget=current_z_template,
                                        zrange=zrange, verbose=verbose, save_suffix='_disc' + str(current_disc),
                                        gaussian_std=999, path_output=path_output)

        # display new disc
        if verbose == 2:
            ax_disc.scatter(yc + param.shift_AP_visu, current_z, c='yellow', s=10)
            ax_disc.text(yc + param.shift_AP_visu + 4, current_z, str(current_disc) + '/' + str(current_disc + 1),
                         verticalalignment='center', horizontalalignment='left', color='yellow', fontsize=7)

        # append to main list
        if direction == 'superior':
            # append at the beginning
            list_disc_z.insert(0, current_z)
            list_disc_value.insert(0, current_disc)
        elif direction == 'inferior':
            # append at the end
            list_disc_z.append(current_z)
            list_disc_value.append(current_disc)

        # adjust correcting factor based on already-identified discs
        if len(list_disc_z) > 1:
            # compute distance between already-identified discs
            list_distance_current = (np.diff(list_disc_z) * (-1)).tolist()
            # retrieve the template distance corresponding to the already-identified discs
            index_disc_identified = [i for i, j in enumerate(list_disc_value_template) if j in list_disc_value[:-1]]
            list_distance_template_identified = [list_distance_template[i] for i in index_disc_identified]
            # divide subject and template distances for the identified discs
            list_subject_to_template_distance = [float(list_distance_current[i]) / list_distance_template_identified[i]
                                                 for i in range(len(list_distance_current))]
            # average across identified discs to obtain an average correcting factor
            correcting_factor = np.mean(list_subject_to_template_distance)
            logger.info('.. correcting factor: %s', correcting_factor)
        else:
            correcting_factor = 1
        # update list_distance specific for the subject
        list_distance = [int(np.round(list_distance_template[i] * correcting_factor))
                         for i in range(len(list_distance_template))]

        # assign new current_z and disc value
        # NOTE(review): if the ValueError below fires on the very first iteration, approx_distance_to_next_disc has
        # not been assigned yet and logging it fails -- confirm whether that situation can occur.
        if direction == 'superior':
            try:
                approx_distance_to_next_disc = list_distance[list_disc_value_template.index(current_disc - 1)]
            except ValueError:
                logger.warning('Disc value not included in template. Using previously-calculated distance: %s',
                               approx_distance_to_next_disc)
            # assign new current_z and disc value
            current_z = current_z + approx_distance_to_next_disc
            current_disc = current_disc - 1
        elif direction == 'inferior':
            try:
                approx_distance_to_next_disc = list_distance[list_disc_value_template.index(current_disc)]
            except ValueError:
                logger.warning('Disc value not included in template. Using previously-calculated distance: %s',
                               approx_distance_to_next_disc)
            # assign new current_z and disc value
            current_z = current_z - approx_distance_to_next_disc
            current_disc = current_disc + 1

        # if current_z is larger than searching zone, switch direction (and start from initial z minus approximate
        # distance from updated template distance)
        if current_z >= nz or current_disc == 0:
            logger.info('.. Switching to inferior direction.')
            direction = 'inferior'
            current_disc = init_disc[1] + 1
            current_z = init_disc[0] - list_distance[list_disc_value_template.index(current_disc)]
        # if current_z is lower than searching zone, stop searching
        if current_z <= 0:
            search_next_disc = False

    if verbose == 2:
        # NOTE(review): the figure is written to the current working directory; path_output is not used here.
        fig_disc.savefig('fig_label_discs.png')

    # add disc above top disc based on mean_distance_adjusted
    upper_disc = min(list_disc_value)
    logger.info('Adding top disc based on adjusted template distance: #%s', upper_disc - 1)
    approx_distance_to_next_disc = list_distance[list_disc_value_template.index(upper_disc - 1)]
    next_z = max(list_disc_z) + approx_distance_to_next_disc
    logger.info('.. approximate distance: %s', approx_distance_to_next_disc)
    # make sure next disc does not go beyond FOV in superior direction
    if next_z > nz:
        list_disc_z.insert(0, nz)
    else:
        list_disc_z.insert(0, next_z)
    # assign disc value
    list_disc_value.insert(0, upper_disc - 1)

    # Label segmentation
    label_segmentation(fname_seg, list_disc_z, list_disc_value, verbose=verbose)
    label_discs(fname_seg, list_disc_z, list_disc_value, verbose=verbose)
def center_of_mass(x):
    """
    Compute the center of mass of an array, refusing input that contains only zeros.

    :param x: numpy array of values
    :return: tuple of float coordinates of the center of mass, one per dimension of x
    :raises ValueError: if every element of x equals zero
    """
    if (x == 0).all():
        raise ValueError("Array has no mass")
    # Call the function from scipy.ndimage directly: scipy.ndimage.measurements is a deprecated namespace.
    return scipy.ndimage.center_of_mass(x)
def create_label_z(fname_seg, z, value, fname_labelz='labelz.nii.gz'):
    """
    Create a label at coordinates x_center, y_center, z

    The segmentation is reoriented to RPI, (x_center, y_center) is the rounded center of mass of the segmentation
    in slice z, and the resulting image (zero everywhere except around the label) is saved in the original
    orientation.

    :param fname_seg: segmentation
    :param z: int: index of the slice (in RPI orientation) where the label is created
    :param value: value written at the label voxel
    :param fname_labelz: string file name of output label
    :return: fname_labelz
    :raises ValueError: if the segmentation is empty in slice z (raised by center_of_mass)
    """
    nii = Image(fname_seg)
    orientation_origin = nii.orientation
    nii = nii.change_orientation("RPI")
    # find x and y coordinates of the centerline at z using center of mass
    x, y = center_of_mass(np.array(nii.data[:, :, z]))
    x, y = int(np.round(x)), int(np.round(y))
    # keep only the label voxel in the image
    nii.data[:, :, :] = 0
    nii.data[x, y, z] = value
    # dilate label to prevent it from disappearing due to nearestneighbor interpolation
    nii.data = dilate(nii.data, 3, 'ball')
    nii.change_orientation(orientation_origin)  # put back in original orientation
    nii.save(fname_labelz)
    return fname_labelz
def get_z_and_disc_values_from_label(fname_label):
    """
    Find z-value and label-value based on labeled image in RPI orientation

    The label position is taken as the rounded center of mass of the image, and the label value is the voxel value
    at that position.

    :param fname_label: image in RPI orientation that contains label
    :return: [z_label, value_label] int list
    """
    img_label = Image(fname_label)
    # round the center of mass to the nearest voxel
    ix, iy, iz = (int(np.round(coord)) for coord in center_of_mass(img_label.data))
    # the label value is read at that voxel
    return [iz, int(img_label.data[ix, iy, iz])]
def clean_labeled_segmentation(fname_labeled_seg, fname_seg, fname_labeled_seg_new):
    """
    Clean labeled segmentation by:
      (i)  removing voxels in segmentation_labeled that are not in segmentation and
      (ii) adding voxels in segmentation that are not in segmentation_labeled

    Voxels added in step (ii) take their value from the labeled segmentation dilated with a ball of size 2.

    :param fname_labeled_seg: file name of the labeled segmentation to clean
    :param fname_seg: file name of the reference segmentation
    :param fname_labeled_seg_new: file name of the output (cleaned) labeled segmentation
    :return: none
    """
    img_labeled_seg = Image(fname_labeled_seg)
    img_seg = Image(fname_seg)
    # (i) remove voxels in segmentation_labeled that are not in segmentation
    data_labeled_seg_mul = img_labeled_seg.data * img_seg.data
    # (ii) dilate to add voxels in segmentation that are not in segmentation_labeled
    data_labeled_seg_dil = dilate(img_labeled_seg.data, 2, 'ball')
    # non-zero where the segmentation has a voxel that carries no label after step (i)
    data_diff = img_seg.data - (data_labeled_seg_mul > 0)
    ind_missing = np.nonzero(data_diff)
    img_labeled_seg_corr = img_labeled_seg.copy()
    img_labeled_seg_corr.data = data_labeled_seg_mul
    # fill all missing voxels at once with the value of the dilated labels
    img_labeled_seg_corr.data[ind_missing] = data_labeled_seg_dil[ind_missing]
    # save new label file (overwrite)
    img_labeled_seg_corr.absolutepath = fname_labeled_seg_new
    img_labeled_seg_corr.save()
def _extract_chunk_zero_padded_z(src, x, xsize, y, yshift, ysize, z_center, zsize):
    """
    Extract from src the block centered on (x, y + yshift, z_center), zero-padded along z outside the volume.

    :param src: 3d source data
    :param x: center of the block along the first axis
    :param xsize: half-size of the block along the first axis
    :param y: center of the block along the second axis (before yshift)
    :param yshift: shift applied to y
    :param ysize: half-size of the block along the second axis
    :param z_center: center of the block along the third axis
    :param zsize: half-size of the block along the third axis
    :return: array whose third dimension is always 2 * zsize + 1
    """
    nz = src.shape[2]
    z_start = z_center - zsize
    z_stop = z_center + zsize + 1
    window_size = z_stop - z_start
    # part of the requested window that lies inside the volume (clamped so that slicing never wraps around)
    z_lo = min(max(z_start, 0), nz)
    z_hi = min(max(z_stop, 0), nz)
    chunk = src[x - xsize: x + xsize + 1,
                y + yshift - ysize: y + yshift + ysize + 1,
                z_lo: z_hi]
    # number of missing slices below and above the volume
    pad_below = min(max(z_lo - z_start, 0), window_size)
    pad_above = window_size - (z_hi - z_lo) - pad_below
    if pad_below or pad_above:
        chunk = np.pad(chunk, ((0, 0), (0, 0), (pad_below, pad_above)), 'constant', constant_values=0)
    return chunk


def compute_corr_3d(src, target, x, xshift, xsize, y, yshift, ysize, z, zshift, zsize, xtarget, ytarget, ztarget,
                    zrange, verbose, save_suffix, gaussian_std, path_output):
    """
    Find z that maximizes the mutual information between src and target 3d data.

    A pattern is cut from target around (xtarget, ytarget + yshift, ztarget + zshift). For every displacement iz in
    zrange, a block of the same size is cut from src around (x, y + yshift, z + iz), zero-padded along z where it
    leaves the volume, and compared to the pattern with mutual_information(). The displacement with the highest
    (Gaussian-weighted) score is selected; if the score is below 0.2 or all scores are zero, a displacement of 0 is
    used instead.

    :param src: 3d source data
    :param target: 3d target data
    :param x: center of the block in src along the first axis
    :param xshift: not used by this function
    :param xsize: half-size of the block along the first axis
    :param y: center of the block in src along the second axis (before yshift)
    :param yshift: shift applied to y and ytarget
    :param ysize: half-size of the block along the second axis
    :param z: center of the block in src along the third axis (before displacement)
    :param zshift: shift applied to ztarget; subtracted from the returned z
    :param zsize: half-size of the block along the third axis
    :param xtarget: center of the pattern in target along the first axis
    :param ytarget: center of the pattern in target along the second axis (before yshift)
    :param ztarget: center of the pattern in target along the third axis (before zshift)
    :param zrange: list of z displacements to test; must contain 0
    :param verbose: 2 to save a figure with the patterns and the similarity curve
    :param save_suffix: suffix of the figure file name ('fig_pattern' + save_suffix + '.png')
    :param gaussian_std: std of the Gaussian weighting, expressed as a multiple of len(zrange)
    :param path_output: not used by this function (the figure is saved in the current working directory)
    :return: z + best displacement - zshift
    """
    # parameters
    thr_corr = 0.2  # disc correlation threshold. Below this value, use template distance.

    # Get pattern from template
    pattern = target[xtarget - xsize: xtarget + xsize + 1,
                     ytarget + yshift - ysize: ytarget + yshift + ysize + 1,
                     ztarget + zshift - zsize: ztarget + zshift + zsize + 1]
    pattern1d = pattern.ravel()

    # initializations
    I_corr = np.zeros(len(zrange))
    allzeros = 0

    # loop across range of z defined by src
    for ind_I, iz in enumerate(zrange):
        # if the block extends beyond the top or bottom of the image, it is cropped and padded with zeros
        data_chunk3d = _extract_chunk_zero_padded_z(src, x, xsize, y, yshift, ysize, z + iz, zsize)
        # convert subject pattern to 1d
        data_chunk1d = data_chunk3d.ravel()
        # check if data_chunk1d has the right size and contains at least one non-zero value
        if (data_chunk1d.size == pattern1d.size) and np.any(data_chunk1d):
            I_corr[ind_I] = mutual_information(data_chunk1d, pattern1d, nbins=16, normalized=False)
        else:
            allzeros = 1

    if allzeros:
        logger.warning('Data contained zero. We probably hit the edge of the image.')

    # adjust correlation with Gaussian function centered at the right edge of the curve (most rostral point of FOV)
    gaussian_window = gaussian(len(I_corr) * 2, std=len(I_corr) * gaussian_std)
    I_corr_gauss = np.multiply(I_corr, gaussian_window[0:len(I_corr)])

    # Find global maximum
    if np.any(I_corr_gauss):
        # if I_corr contains at least a non-zero value: index of the (first) max along z
        ind_peak = int(np.argmax(I_corr_gauss))
        logger.info('.. Peak found: z=%s (correlation = %s)', zrange[ind_peak], I_corr_gauss[ind_peak])
        # check if correlation is high enough
        if I_corr_gauss[ind_peak] < thr_corr:
            logger.warning('Correlation is too low. Using adjusted template distance.')
            ind_peak = zrange.index(0)  # approx_distance_to_next_disc
    else:
        # if I_corr contains only zeros
        logger.warning('Correlation vector only contains zeros. Using adjusted template distance.')
        ind_peak = zrange.index(0)  # approx_distance_to_next_disc

    # display patterns and correlation
    if verbose == 2:
        from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
        from matplotlib.figure import Figure
        fig = Figure(figsize=(15, 7))
        FigureCanvas(fig)
        # display template pattern
        ax = fig.add_subplot(131)
        ax.imshow(np.flipud(np.mean(pattern[:, :, :], axis=0).transpose()), origin='upper', cmap='gray',
                  interpolation='none')
        ax.set_title('Template pattern')
        # display subject pattern at best z
        ax = fig.add_subplot(132)
        iz = zrange[ind_peak]
        data_chunk3d = _extract_chunk_zero_padded_z(src, x, xsize, y, yshift, ysize, z + iz, zsize)
        ax.imshow(np.flipud(np.mean(data_chunk3d[:, :, :], axis=0).transpose()), origin='upper', cmap='gray',
                  clim=[0, 800], interpolation='none')
        ax.set_title('Subject at iz=' + str(iz))
        # display correlation curve
        ax = fig.add_subplot(133)
        ax.plot(zrange, I_corr)
        ax.plot(zrange, I_corr_gauss, 'black', linestyle='dashed')
        ax.legend(['I_corr', 'I_corr_gauss'])
        ax.set_title('Mutual Info, gaussian_std=' + str(gaussian_std))
        ax.plot(zrange[ind_peak], I_corr_gauss[ind_peak], 'ro')
        # the x axis holds displacement values (zrange), so the zero-displacement marker is at x=0
        ax.axvline(x=0, linewidth=1, color='black', linestyle='dashed')
        ax.axhline(y=thr_corr, linewidth=1, color='r', linestyle='dashed')
        ax.grid()
        # save figure
        fig.savefig('fig_pattern' + save_suffix + '.png')

    # return z-origin (z) + z-displacement minus zshift (to account for non-centered disc)
    return z + zrange[ind_peak] - zshift
def label_segmentation(fname_seg, list_disc_z, list_disc_value, verbose=1):
    """
    Label segmentation image

    Every non-zero voxel of a slice receives the vertebral level of that slice; the result is saved next to
    fname_seg with the suffix "_labeled", in the original orientation.

    :param fname_seg: fname of the segmentation, no orientation expected
    :param list_disc_z: list of z that correspond to a disc
    :param list_disc_value: list of associated disc values
    :param verbose: not used by this function
    :return: None
    """
    # open segmentation and work in RPI
    seg = Image(fname_seg)
    init_orientation = seg.orientation
    seg.change_orientation("RPI")
    nz = seg.dim[2]

    for iz in range(nz):
        # positions, within list_disc_z, of the discs located above this slice
        discs_above = [idx for idx, disc_z in enumerate(list_disc_z) if disc_z > iz]
        if discs_above:
            # use the last of them (add one because iz is BELOW the disk)
            vertebral_level = list_disc_value[discs_above[-1]] + 1
        else:
            # no disc above this slice: attribute value 0
            vertebral_level = 0
        # overwrite the voxels of the mask in this slice
        mask_idx = np.nonzero(seg.data[:, :, iz])
        seg.data[mask_idx[0], mask_idx[1], iz] = vertebral_level

    # write file
    seg.change_orientation(init_orientation).save(add_suffix(fname_seg, '_labeled'))
def label_discs(fname_seg, list_disc_z, list_disc_value, verbose=1):
    """
    Create file with single voxel label in the middle of the spinal cord for each disc.

    The output is saved next to fname_seg with the suffix "_labeled_disc", in the original orientation. Discs whose
    z is not below the number of slices are skipped.

    :param fname_seg: fname of the segmentation, no orientation expected
    :param list_disc_z: list of z that correspond to a disc
    :param list_disc_value: list of associated disc values
    :param verbose: not used by this function
    :return: None
    """
    seg = Image(fname_seg)
    init_orientation = seg.orientation
    seg.change_orientation("RPI")
    nx, ny, nz = seg.data.shape
    disc_data = np.zeros_like(seg.data)

    for idx, disc_z in enumerate(list_disc_z):
        if not disc_z < nz:
            continue
        # place the label at the (rounded) center of mass of the segmentation in the disc slice
        plane = seg.data[:, :, disc_z]
        cx, cy = (int(coord) for coord in np.round(center_of_mass(plane)).tolist())
        # Disc value are offset by one due to legacy code
        disc_data[cx, cy, disc_z] = list_disc_value[idx] + 1

    seg.data = disc_data
    seg.change_orientation(init_orientation).save(add_suffix(fname_seg, '_labeled_disc'))