# Source code for spinalcordtoolbox.deepseg_sc.postprocessing

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Deals with postprocessing on generated segmentation: remove outliers, fill holes, etc.


import logging
import numpy as np
from scipy.ndimage.measurements import label
from scipy.ndimage.morphology import binary_fill_holes

from ..process_seg import compute_shape


# Module-level logger, named after this module's import path; used by the helpers below to emit warnings.
logger = logging.getLogger(__name__)


def _fill_z_holes(zz_lst, data, z_spaccing):
    """
    Fill short runs of empty z-slices using the slices located on either side of each run.

    Each run of contiguous indices in ``zz_lst`` is treated as one hole. A hole is filled only when
    ``(number of slices in the hole + 1) * z_spaccing`` is below 10, and only when it has a slice on both
    sides inside the volume.

    :param zz_lst: list of int: Indices of the z-slices to fill.
    :param data: 3d-array: Segmentation indexed as [x, y, z]. Not modified.
    :param z_spaccing: float: Spacing between two consecutive z-slices, in the unit the threshold of 10 is
        expressed in (presumably mm -- confirm against the caller).
    :return: 3d-array: Copy of ``data`` in which the eligible hole slices have been filled with 0/1 values.
    """
    data_interpol = np.copy(data)
    if len(zz_lst) == 0:
        # Nothing to fill; also avoids asking the range helper to process an empty list.
        return data_interpol

    n_z = data.shape[2]
    for z_hole_start, z_hole_end in _list2range(list(zz_lst)):
        # Reference slices: the one just below the hole and the one just above it
        # (z_hole_end is exclusive, so it is already the first slice after the hole).
        z_ref_start, z_ref_end = z_hole_start - 1, z_hole_end
        if z_ref_start < 0 or z_ref_end >= n_z:
            # A hole touching the volume boundary has no reference slice on one side: index -1 would
            # silently wrap around to the last slice and index n_z would raise.
            continue
        slice_ref_start, slice_ref_end = data[:, :, z_ref_start], data[:, :, z_ref_end]

        hole_cur_lst = list(range(z_hole_start, z_hole_end))
        lenght_hole = len(hole_cur_lst) + 1
        phys_lenght_hole = lenght_hole * z_spaccing

        denom_interpolation = (lenght_hole + 1)

        if phys_lenght_hole < 10:
            logger.warning('Filling a hole in the segmentation around z_slice #:%s', z_ref_start)

            for idx_z, z_hole_cur in enumerate(hole_cur_lst):
                # Linear blend of the two reference slices. Both weights are always >= 1, so after the
                # "> 0" threshold a voxel is set as soon as it is positive in either reference slice
                # (for non-negative input data).
                num_interpolation = (lenght_hole - idx_z - 1) * slice_ref_start  # Contribution of the bottom ref slice
                num_interpolation += (idx_z + 1) * slice_ref_end  # Contribution of the top ref slice

                slice_interpolation = num_interpolation * 1. / denom_interpolation
                # Builtin int replaces the np.int alias, which was removed in NumPy 1.24.
                slice_interpolation = (slice_interpolation > 0).astype(int)

                data_interpol[:, :, z_hole_cur] = slice_interpolation

    return data_interpol


def _list2range(lst):
    """
    Group integers into runs of consecutive values.

    :param lst: iterable of int: Values to group, in any order. Not modified.
    :return: generator of [start, stop] lists, one per run, where ``stop`` is exclusive
        (e.g. ``[1, 2, 3, 7]`` yields ``[1, 4]`` then ``[7, 8]``). Nothing is yielded for an empty input.
    """
    ordered = sorted(lst)  # works on any iterable and leaves the caller's object untouched
    if not ordered:
        return

    currentrange = [ordered[0], ordered[0] + 1]

    for item in ordered[1:]:
        if currentrange[1] == item:  # contiguous
            currentrange[1] += 1
        else:  # new range start
            yield list(currentrange)
            currentrange = [item, item + 1]

    yield list(currentrange)  # last range


def _remove_blobs(data):
    """
    Remove small connected objects located above the largest connected object.

    The largest connected object is kept as is. Among the objects lying entirely in z-slices above its highest
    slice, those whose volume is below 10% of the largest object's volume are set to zero.

    :param data: 3d-array: Segmentation indexed as [x, y, z]. Modified in place.
    :return: 3d-array: The same ``data`` array, cleaned.
    """
    components, n_components = label(data)
    if n_components <= 1:
        # Zero or one connected object: nothing can be an outlier
        return data

    # Label 0 is the background, hence the [1:] and the +1 to get back to a label value
    voxels_per_label = np.bincount(components.ravel())
    main_label = voxels_per_label[1:].argmax() + 1
    main_mask = components == main_label
    main_volume = np.sum(main_mask)

    # Only look at what lies strictly above the highest slice of the main object
    top_slice = np.max(np.nonzero(main_mask)[2])
    above_main = data.copy()
    above_main[:, :, :top_slice + 1] = 0

    candidates, n_candidates = label(above_main)
    for candidate_label in range(1, n_candidates + 1):
        candidate_mask = candidates == candidate_label
        # A candidate smaller than 10% of the main object is considered a false positive
        if np.sum(candidate_mask) < 0.1 * main_volume:
            logger.warning('Removing small objects above slice #' + str(top_slice))
            data[candidate_mask] = 0

    return data


def _remove_extrem_holes(z_lst, end_z, start_z=0):
    """
    Remove extrem holes from the list of holes so that we will not interpolate on the extrem slices.

    The run of consecutive indices starting at ``start_z`` (going up) and the run ending at ``end_z``
    (going down) are dropped; every other index is kept.

    :param z_lst: iterable of int: Indices of the empty z-slices. Not modified.
    :param end_z: int: Index of the last z-slice of the volume.
    :param start_z: int: Index of the first z-slice of the volume.
    :return: list of int: Sorted indices of the holes that are not attached to either extremity.
    """
    holes = set(z_lst)

    # Strip the run attached to the first slice. Only that run is removed: the first slice of the next
    # (interior) hole must stay in the list, otherwise that hole is left partly or entirely unfilled.
    while start_z in holes:
        holes.discard(start_z)
        start_z += 1

    # Strip the run attached to the last slice
    while end_z in holes:
        holes.discard(end_z)
        end_z -= 1

    return sorted(holes)


def _remove_isolated_voxels_on_the_edge(im_seg, n_slices=5):
    """
    Remove isolated voxels on the edge if the CSA of the edge slice is smaller than half the median of adjacent slices.

    The two edge slices are the first and the last z-slices for which ``compute_shape`` reports a non-NaN
    area. Each one is compared with the median area of the ``n_slices`` slices starting from that edge
    (edge slice included) and set to zero if it is below half of that median.

    :param im_seg: Segmentation image accepted by ``compute_shape`` and exposing a ``data`` array indexed as
        [x, y, z]. Modified in place.
    :param n_slices: Number of adjacent slices to consider. If not enough slices, this test will be bypassed.
    :return: The same ``im_seg`` object.
    """
    # Compute shape info across segmentation
    metrics, _ = compute_shape(im_seg, angle_correction=False)
    # NOTE(review): assumes the area metric holds one value per z-slice, NaN where nothing is segmented.
    area = np.asarray(metrics['area'].data)

    # Extract CSA and get min/max index, corresponding to the two edges of the segmentation
    ind_nonnan = np.where(~np.isnan(area))[0]
    if ind_nonnan.size == 0:
        # Empty segmentation: there is no edge slice to inspect
        return im_seg
    ind_min, ind_max = ind_nonnan[0], ind_nonnan[-1]

    if ind_max - ind_min + 1 < n_slices:
        # Not enough segmented slices to build a full window on either side: bypass the test.
        # This also guarantees that the window bounds below stay inside [ind_min, ind_max].
        return im_seg

    # Check if the CSA at the edge is inferior to half of the median across adjacent slices...
    # ... for the lowest-index edge slice (window starts at ind_min, not at 0)
    if area[ind_min] < np.median(area[ind_min:ind_min + n_slices]) / 2:
        im_seg.data[:, :, ind_min] = 0
        logger.warning('Found isolated voxels on slice {}, Removing them'.format(ind_min))
    # ... for the highest-index edge slice
    if area[ind_max] < np.median(area[ind_max - n_slices + 1:ind_max + 1]) / 2:
        im_seg.data[:, :, ind_max] = 0
        logger.warning('Found isolated voxels on slice {}, Removing them'.format(ind_max))
    return im_seg


def post_processing_volume_wise(im_seg):
    """
    Post processing function to clean the input segmentation: fill holes, remove edge outlier, etc.

    Note: This function is compatible with soft segmentation (i.e. float between 0-1).

    :param im_seg: Segmentation image exposing a ``data`` array indexed as [x, y, z] and a ``dim`` sequence
        (``dim[2]`` is used as the number of z-slices, ``dim[6]`` as the z spacing). Modified in place.
    :return: The same ``im_seg`` object, cleaned.
    """
    # Builtin int replaces the np.int alias, which was removed in NumPy 1.24.
    data_bin = (im_seg.data > 0).astype(int)  # will binarize soft segmentation

    # Remove blobs
    data_bin = _remove_blobs(data_bin)

    # Fill z_holes, i.e. interpolate for z_slice not segmented
    n_z = im_seg.dim[2]
    zz_zeros = [zz for zz in range(n_z) if not np.any(data_bin[:, :, zz] == 1)]
    zz_holes = _remove_extrem_holes(zz_zeros, n_z - 1, 0)
    data_pp = _fill_z_holes(zz_holes, data_bin, im_seg.dim[6]) if len(zz_holes) else data_bin

    # Only zeros are propagated back, so that the original (possibly soft) values are kept elsewhere.
    # NOTE(review): as a consequence, voxels set to 1 by the hole filling are not written to im_seg.data.
    im_seg.data[np.where(data_pp == 0)] = 0  # to be compatible with soft segmentation

    # Set isolated voxels at edge slices to zero
    im_seg = _remove_isolated_voxels_on_the_edge(im_seg)

    return im_seg


def keep_largest_object(z_slice_bin, x_cOm, y_cOm):
    """
    Keep a single connected object in a 2d slice.

    If the slice contains more than one connected object, the object located at the given center of mass is
    kept. When no center of mass is available (None or NaN), or when it falls on the background, the largest
    object is kept instead.

    Note: This function only works for binary segmentation.

    :param z_slice_bin: int 2d-array: Input 2d segmentation. Modified in place.
    :param x_cOm: int: X center of mass of the segmentation for the previous 2d slice
    :param y_cOm: int: Y center of mass of the segmentation for the previous 2d slice
    :return: z_slice_bin: int 2d-array: Processed 2d segmentation
    """
    assert z_slice_bin.dtype == np.dtype('int')

    # Find connected objects using scipy.ndimage "label"
    labeled_obj, num_obj = label(z_slice_bin)
    if num_obj <= 1:
        # Nothing to discard
        return z_slice_bin

    # Default choice: the largest object (label 0 is the background, hence the [1:] and the +1)
    idx_keep = np.bincount(labeled_obj.flat)[1:].argmax() + 1

    # The center of mass is unusable if either coordinate is missing (e.g. first slice, or the previous
    # segmentation was empty). None is tested first because np.isnan(None) raises.
    com_missing = x_cOm is None or y_cOm is None or np.isnan(x_cOm) or np.isnan(y_cOm)
    if not com_missing:
        idx_at_com = labeled_obj[int(x_cOm), int(y_cOm)]
        if idx_at_com != 0:
            # The previous center of mass lies on an object: prefer that one over the largest
            idx_keep = idx_at_com

    z_slice_bin[labeled_obj != idx_keep] = 0
    return z_slice_bin


def fill_holes_2d(z_slice):
    """
    Fill holes in the segmentation.

    A full 3x3 structuring element is passed to ``binary_fill_holes``.

    :param z_slice: int 2d-array: Input 2D segmentation.
    :return: int 2d-array: Output segmentation with holes filled
    """
    assert z_slice.dtype == np.dtype('int')
    # Builtin int replaces the np.int alias, which was removed in NumPy 1.24.
    return binary_fill_holes(z_slice, structure=np.ones((3, 3))).astype(int)