#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Functions dealing with deepseg_sc
import os
import sys
import logging
import numpy as np
from skimage.exposure import rescale_intensity
from scipy.ndimage.measurements import center_of_mass, label
from scipy.ndimage import distance_transform_edt
from spinalcordtoolbox import resampling
from .cnn_models import nn_architecture_seg, nn_architecture_ctr
from .postprocessing import post_processing_volume_wise, keep_largest_object, fill_holes_2d
from spinalcordtoolbox.image import Image, empty_like, change_type, zeros_like, add_suffix, concat_data, split_img_data
from spinalcordtoolbox.centerline.core import ParamCenterline, get_centerline, _call_viewer_centerline
from spinalcordtoolbox.utils import sct_dir_local_path, TempFolder
from spinalcordtoolbox.deepseg_sc.cnn_models_3d import load_trained_model
# Silence TensorFlow C++ logging (3 = errors only), must be set before TF is initialized
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Number of samples per call to model.predict()
BATCH_SIZE = 4

# Thresholds to apply to binarize segmentations from the output of the 2D CNN. These thresholds were obtained by
# minimizing the standard deviation of cross-sectional area across contrasts. For more details, see:
# https://github.com/sct-pipeline/deepseg-threshold
THR_DEEPSEG = {'t1': 0.15, 't2': 0.7, 't2s': 0.89, 'dwi': 0.01}

logger = logging.getLogger(__name__)
def find_centerline(algo, image_fname, contrast_type, brain_bool, folder_output, remove_temp_files, centerline_fname):
    """
    Find the spinal cord centerline of the input image using the requested method.

    Assumes RPI orientation.

    :param algo: str: centerline detection method: {'svm', 'cnn', 'viewer', 'file'}
    :param image_fname: str: path to the input image (assumed RPI-oriented)
    :param contrast_type: str: {'t1', 't2', 't2s', 'dwi'}: selects the trained model and its parameters
    :param brain_bool: bool: True when the image may include the brain (enables brain-section cropping)
    :param folder_output: not used in this function (kept for call-site compatibility)
    :param remove_temp_files: not used in this function (kept for call-site compatibility)
    :param centerline_fname: str: path to an existing centerline file (only used when algo == 'file')
    :return: (str, Image, Image or None): dummy filename, centerline image, and the viewer label
             image (None for every algo except 'viewer')
    """
    im = Image(image_fname)
    ctl_absolute_path = add_suffix(im.absolutepath, "_ctr")

    # isct_spine_detect requires nz > 1
    if im.dim[2] == 1:
        im = concat_data([im, im], dim=2)
        im.hdr['dim'][3] = 2  # Needs to be change manually since dim not updated during concat_data
        bool_2d = True
    else:
        bool_2d = False

    # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        # optic_models_fname = os.path.join(path_sct, 'data', 'optic_models', '{}_model'.format(contrast_type))
        # # TODO: replace with get_centerline(method=optic)
        im_ctl, _, _, _ = get_centerline(im,
                                         ParamCenterline(algo_fitting='optic', contrast=contrast_type))

    elif algo == 'cnn':
        # CNN parameters: per-contrast patch size and training-set intensity statistics
        dct_patch_ctr = {'t2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408},
                         't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659},
                         't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149},
                         'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003}}
        # per-contrast architecture hyper-parameters
        dct_params_ctr = {'t2': {'features': 16, 'dilation_layers': 2},
                          't2s': {'features': 8, 'dilation_layers': 3},
                          't1': {'features': 24, 'dilation_layers': 3},
                          'dwi': {'features': 8, 'dilation_layers': 2}}

        # load model
        ctr_model_fname = sct_dir_local_path('data', 'deepseg_sc_models', '{}_ctr.h5'.format(contrast_type))
        ctr_model = nn_architecture_ctr(height=dct_patch_ctr[contrast_type]['size'][0],
                                        width=dct_patch_ctr[contrast_type]['size'][1],
                                        channels=1,
                                        classes=1,
                                        features=dct_params_ctr[contrast_type]['features'],
                                        depth=2,
                                        temperature=1.0,
                                        padding='same',
                                        batchnorm=True,
                                        dropout=0.0,
                                        dilation_layers=dct_params_ctr[contrast_type]['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)

        # compute the heatmap
        im_heatmap, z_max = heatmap(im=im,
                                    model=ctr_model,
                                    patch_shape=dct_patch_ctr[contrast_type]['size'],
                                    mean_train=dct_patch_ctr[contrast_type]['mean'],
                                    std_train=dct_patch_ctr[contrast_type]['std'],
                                    brain_bool=brain_bool)
        im_ctl, _, _, _ = get_centerline(im_heatmap,
                                         ParamCenterline(algo_fitting='optic', contrast=contrast_type))

        if z_max is not None:
            # zero-out the centerline above z_max, where the brain section starts
            logger.info('Cropping brain section.')
            im_ctl.data[:, :, z_max:] = 0

    elif algo == 'viewer':
        # manual labeling through the viewer, followed by curve fitting between the clicked points
        im_labels = _call_viewer_centerline(im)
        im_ctl, _, _, _ = get_centerline(im_labels, param=ParamCenterline())

    elif algo == 'file':
        im_ctl = Image(centerline_fname)
        im_ctl.change_orientation('RPI')

    else:
        logger.error('The parameter "-centerline" is incorrect. Please try again.')
        sys.exit(1)

    # TODO: for some reason, when algo == 'file', the absolutepath is changed to None out of the method find_centerline
    im_ctl.absolutepath = ctl_absolute_path

    if bool_2d:
        # undo the nz == 1 workaround above: keep only the first of the two duplicated slices
        im_ctl = split_img_data(im_ctl, dim=2)[0]

    if algo != 'viewer':
        im_labels = None

    # TODO: remove unecessary return params
    return "dummy_file_name", im_ctl, im_labels
def scale_intensity(data, out_min=0, out_max=255):
    """Linearly rescale ``data`` into ``[out_min, out_max]``.

    The input range is taken as the 2nd-98th percentile window, so that a few
    outlier voxels do not dominate the rescaling.
    """
    in_low = np.percentile(data, 2)
    in_high = np.percentile(data, 98)
    return rescale_intensity(data, in_range=(in_low, in_high), out_range=(out_min, out_max))
def apply_intensity_normalization(im_in, params=None):
    """Return a float32 copy of ``im_in`` with its intensity range standardized.

    :param im_in: Image: input image
    :param params: unused; kept for backward compatibility of the signature
    :return: Image: float32 image whose data has been rescaled by :func:`scale_intensity`
    """
    im_float = im_in.change_type(np.float32)
    im_float.data = scale_intensity(im_float.data)
    return im_float
def _find_crop_start_end(coord_ctr, crop_size, im_dim):
"""Util function to find the coordinates to crop the image around the centerline (coord_ctr)."""
half_size = crop_size // 2
coord_start, coord_end = int(coord_ctr) - half_size + 1, int(coord_ctr) + half_size + 1
if coord_end > im_dim:
coord_end = im_dim
coord_start = im_dim - crop_size if im_dim >= crop_size else 0
if coord_start < 0:
coord_start = 0
coord_end = crop_size if im_dim >= crop_size else im_dim
return coord_start, coord_end
def crop_image_around_centerline(im_in, ctr_in, crop_size):
    """Crop the input image around the input centerline file."""
    data_ctr = ctr_in.data
    if len(data_ctr.shape) < 3:
        data_ctr = np.expand_dims(data_ctr, 2)
    data_in = im_in.data.astype(np.float32)
    im_new = empty_like(im_in)  # its data array is replaced below by the cropped volume

    x_lst, y_lst, z_lst = [], [], []
    nz = im_in.dim[2]
    data_im_new = np.zeros((crop_size, crop_size, nz))
    for zz in range(nz):
        slice_ctr = np.array(data_ctr[:, :, zz])
        if np.any(slice_ctr):
            # centre the in-plane crop window on the centerline's centre of mass
            x_ctr, y_ctr = center_of_mass(slice_ctr)
            x_start, x_end = _find_crop_start_end(x_ctr, crop_size, im_in.dim[0])
            y_start, y_end = _find_crop_start_end(y_ctr, crop_size, im_in.dim[1])

            patch = data_in[x_start:x_end, y_start:y_end, zz]
            # zero-pad when the extracted patch is smaller than crop_size (image edge)
            padded = np.zeros((crop_size, crop_size))
            padded[:patch.shape[0], :patch.shape[1]] = patch
            data_im_new[:, :, zz] = padded

            x_lst.append(str(x_start))
            y_lst.append(str(y_start))
            z_lst.append(zz)

    im_new.data = data_im_new
    return x_lst, y_lst, z_lst, im_new
def scan_slice(z_slice, model, mean_train, std_train, coord_lst, patch_shape, z_out_dim):
    """Scan the entire axial slice to detect the centerline."""
    z_slice_out = np.zeros(z_out_dim)
    sum_lst = []
    # loop across all the non-overlapping blocks of a cross-sectional slice
    for patch_coord in coord_lst:
        x0, y0, x1, y1 = patch_coord
        patch = z_slice[x0:x1, y0:y1]
        patch_nn = np.expand_dims(np.expand_dims(patch, 0), -1)
        patch_nn_norm = _normalize_data(patch_nn, mean_train, std_train)
        patch_pred = model.predict(patch_nn_norm, batch_size=BATCH_SIZE)

        # clip the prediction back to the output extent (blocks may overhang the padded edge)
        x_keep = patch_shape[0] - (x1 - z_out_dim[0]) if x1 > z_out_dim[0] else patch_shape[0]
        y_keep = patch_shape[1] - (y1 - z_out_dim[1]) if y1 > z_out_dim[1] else patch_shape[1]
        z_slice_out[x0:x1, y0:y1] = patch_pred[0, :x_keep, :y_keep, 0]
        sum_lst.append(np.sum(patch_pred[0, :x_keep, :y_keep, 0]))

    # Put first the coord of the patch were the centerline is likely located so that the search could be faster for the
    # next axial slices
    best_idx = sum_lst.index(max(sum_lst))
    coord_lst.insert(0, coord_lst.pop(best_idx))

    # computation of the new center of mass
    if np.max(z_slice_out) > 0.5:
        z_slice_out_bin = z_slice_out > 0.5
        labeled_mask, numpatches = label(z_slice_out_bin)
        # keep the largest connected component of the binarized prediction
        largest_cc_mask = (labeled_mask == (np.bincount(labeled_mask.flat)[1:].argmax() + 1))
        x_CoM, y_CoM = center_of_mass(largest_cc_mask)
        x_CoM, y_CoM = int(x_CoM), int(y_CoM)
    else:
        x_CoM, y_CoM = None, None

    return z_slice_out, x_CoM, y_CoM, coord_lst
def heatmap(im, model, patch_shape, mean_train, std_train, brain_bool=True):
    """Compute the heatmap with CNN_1 representing the SC localization.

    Axial slices are processed from bottom to top. For each slice, the prediction is first run on a
    single patch centred on the centre of mass (CoM) found on the previous slice; if nothing was
    detected there (or no previous CoM exists), the whole slice is scanned block-wise with
    :func:`scan_slice`.

    :param im: Image: input image (assumed RPI-oriented; axial plane = first two axes)
    :param model: trained Keras model predicting the centerline heatmap on 2D patches
    :param patch_shape: (int, int): in-plane patch size expected by the model
    :param mean_train: float: training-set mean used for intensity normalization
    :param std_train: float: training-set standard deviation used for intensity normalization
    :param brain_bool: bool: if True, stop scanning after >10 consecutive slices without detection
                       (assumed to have reached the brain section)
    :return: (Image, int or None): the heatmap image and z_max, the top-most slice with signal
             (None when the whole volume contains signal, i.e. nothing to reject)
    """
    data_im = im.data.astype(np.float32)
    im_out = change_type(im, "uint8")
    del im
    data = np.zeros(im_out.data.shape)

    x_shape, y_shape = data_im.shape[:2]
    # Number of patch-sized blocks along each in-plane axis.
    # FIX: `np.int` was deprecated in NumPy 1.20 and removed in 1.24; use the builtin `int` instead.
    # x uses ceil (missing voxels are padded below), y uses floor (extra voxels are cropped below).
    x_shape_block = int(np.ceil(x_shape * 1.0 / patch_shape[0]))
    y_shape_block = int(y_shape * 1.0 / patch_shape[1])
    x_pad = int(x_shape_block * patch_shape[0] - x_shape)
    if y_shape > patch_shape[1]:
        y_crop = y_shape - y_shape_block * patch_shape[1]
        # slightly crop the input data in the P-A direction so that data_im.shape[1] % patch_shape[1] == 0
        data_im = data_im[:, :y_shape - y_crop, :]
        # coordinates of the blocks to scan during the detection, in the cross-sectional plane
        coord_lst = [[x_dim * patch_shape[0], y_dim * patch_shape[1],
                      (x_dim + 1) * patch_shape[0], (y_dim + 1) * patch_shape[1]]
                     for y_dim in range(y_shape_block) for x_dim in range(x_shape_block)]
    else:
        # pad the P-A direction up to one patch
        data_im = np.pad(data_im, ((0, 0), (0, patch_shape[1] - y_shape), (0, 0)), 'constant')
        coord_lst = [[x_dim * patch_shape[0], 0, (x_dim + 1) * patch_shape[0], patch_shape[1]] for x_dim in
                     range(x_shape_block)]

    # pad the input data in the R-L direction
    data_im = np.pad(data_im, ((0, x_pad), (0, 0), (0, 0)), 'constant')
    # scale intensities between 0 and 255
    data_im = scale_intensity(data_im)

    x_CoM, y_CoM = None, None
    z_sc_notDetected_cmpt = 0
    for zz in range(data_im.shape[2]):
        # if SC was detected at zz-1, we will start doing the detection on the block centered around the previously
        # computed center of mass (CoM)
        if x_CoM is not None:
            z_sc_notDetected_cmpt = 0  # SC detected, cmpt set to zero
            x_0, x_1 = _find_crop_start_end(x_CoM, patch_shape[0], data_im.shape[0])
            y_0, y_1 = _find_crop_start_end(y_CoM, patch_shape[1], data_im.shape[1])
            block = data_im[x_0:x_1, y_0:y_1, zz]
            block_nn = np.expand_dims(np.expand_dims(block, 0), -1)
            block_nn_norm = _normalize_data(block_nn, mean_train, std_train)
            block_pred = model.predict(block_nn_norm, batch_size=BATCH_SIZE)

            # coordinates manipulation due to the above padding and cropping
            if x_1 > data.shape[0]:
                x_end = data.shape[0]
                x_1 = data.shape[0]
                x_0 = data.shape[0] - patch_shape[0] if data.shape[0] > patch_shape[0] else 0
            else:
                x_end = patch_shape[0]
            if y_1 > data.shape[1]:
                y_end = data.shape[1]
                y_1 = data.shape[1]
                y_0 = data.shape[1] - patch_shape[1] if data.shape[1] > patch_shape[1] else 0
            else:
                y_end = patch_shape[1]

            data[x_0:x_1, y_0:y_1, zz] = block_pred[0, :x_end, :y_end, 0]

            # computation of the new center of mass
            if np.max(data[:, :, zz]) > 0.5:
                z_slice_out_bin = data[:, :, zz] > 0.5  # binarized prediction: SC detected on this slice
                x_CoM, y_CoM = center_of_mass(z_slice_out_bin)
                x_CoM, y_CoM = int(x_CoM), int(y_CoM)
            else:
                x_CoM, y_CoM = None, None

        # if the SC was not detected at zz-1 or on the patch centered around CoM in slice zz, the entire cross-sectional
        # slice is scanned
        if x_CoM is None:
            z_slice, x_CoM, y_CoM, coord_lst = scan_slice(data_im[:, :, zz], model,
                                                          mean_train, std_train,
                                                          coord_lst, patch_shape, data.shape[:2])
            data[:, :, zz] = z_slice

            z_sc_notDetected_cmpt += 1
            # if the SC has not been detected on 10 consecutive z_slices, we stop the SC investigation
            if z_sc_notDetected_cmpt > 10 and brain_bool:
                logger.info('Brain section detected.')
                break

        # distance transform to deal with the harsh edges of the prediction boundaries (Dice)
        data[:, :, zz][np.where(data[:, :, zz] < 0.5)] = 0
        data[:, :, zz] = distance_transform_edt(data[:, :, zz])

    if not np.any(data):
        logger.error(
            '\nSpinal cord was not detected using "-centerline cnn". Please try another "-centerline" method.\n')
        sys.exit(1)

    im_out.data = data

    # z_max is used to reject brain sections
    z_max = np.max(list(set(np.where(data)[2])))
    if z_max == data.shape[2] - 1:
        return im_out, None
    else:
        return im_out, z_max
def _normalize_data(data, mean, std):
"""Util function to normalized data based on learned mean and std."""
data -= mean
data /= std
return data
def segment_2d(model_fname, contrast_type, input_size, im_in):
    """
    Segment data using 2D convolutions.

    :return: seg_crop.data: ndarray float32: Output prediction
    """
    # t2 uses a deeper network than the other contrasts
    seg_model = nn_architecture_seg(height=input_size[0],
                                    width=input_size[1],
                                    depth=3 if contrast_type == 't2' else 2,
                                    features=32,
                                    batchnorm=False,
                                    dropout=0.0)
    seg_model.load_weights(model_fname)

    seg_crop = zeros_like(im_in, dtype=np.float32)

    data_norm = im_in.data
    # TODO: use sct_progress_bar
    for zz in range(im_in.dim[2]):
        # 2D CNN prediction: add batch and channel axes, then strip them from the output
        slice_in = np.expand_dims(np.expand_dims(data_norm[:, :, zz], -1), 0)
        seg_crop.data[:, :, zz] = seg_model.predict(slice_in, batch_size=BATCH_SIZE)[0, :, :, 0]
    return seg_crop.data
def segment_3d(model_fname, contrast_type, im_in):
    """
    Perform segmentation with 3D convolutions.

    :return: out.data: ndarray float32: Output prediction
    """
    dct_patch_sc_3d = {'t2': {'size': (64, 64, 48), 'mean': 65.8562, 'std': 59.7999},
                       't2s': {'size': (96, 96, 48), 'mean': 87.0212, 'std': 64.425},
                       't1': {'size': (64, 64, 48), 'mean': 88.5001, 'std': 66.275}}
    contrast_params = dct_patch_sc_3d[contrast_type]

    # load 3d model
    seg_model = load_trained_model(model_fname)
    out = zeros_like(im_in, dtype=np.float32)

    # segment the spinal cord, one z-slab of z_patch_size slices at a time
    nz = im_in.data.shape[2]
    z_patch_size = contrast_params['size'][2]
    z_starts = list(range(0, nz, z_patch_size))
    # TODO: use sct_progress_bar
    for zz in z_starts:
        is_last = zz == z_starts[-1]
        if is_last:
            # the last slab may exceed the volume (nz % z_patch_size != 0): zero-pad it along z
            z_patch_extracted = nz - zz
            patch_im = np.zeros(contrast_params['size'])
            patch_im[:, :, :z_patch_extracted] = im_in.data[:, :, zz:]
        else:
            z_patch_extracted = z_patch_size
            patch_im = im_in.data[:, :, zz:zz + z_patch_size]

        if np.any(patch_im):  # Check if the patch is (not) empty, which could occur after a brain detection.
            patch_norm = _normalize_data(patch_im, contrast_params['mean'], contrast_params['std'])
            patch_pred_proba = seg_model.predict(np.expand_dims(np.expand_dims(patch_norm, 0), 0),
                                                 batch_size=BATCH_SIZE)
            # NOTE: despite the name, this is the raw probability map, not a thresholded mask
            pred_seg_th = patch_pred_proba[0, 0, :, :, :]

            # write the prediction back at the slab's position, dropping the padded z slices
            if is_last:
                out.data[:, :, zz:] = pred_seg_th[:, :, :z_patch_extracted]
            else:
                out.data[:, :, zz:zz + z_patch_size] = pred_seg_th
    return out.data
def uncrop_image(ref_in, data_crop, x_crop_lst, y_crop_lst, z_crop_lst):
    """
    Paste the cropped, per-slice segmentation back into a volume shaped like ``ref_in``.

    :param ref_in: Image: reference image defining the output geometry
    :param data_crop: ndarray: cropped prediction, shape (crop_x, crop_y, nz)
    :param x_crop_lst, y_crop_lst: per-slice crop origins (stored as strings by the crop step)
    :param z_crop_lst: slice indices that were actually cropped/predicted
    :return: Image: float32 un-cropped segmentation
    """
    seg_unCrop = zeros_like(ref_in, dtype=np.float32)
    crop_x, crop_y = data_crop.shape[:2]
    dim_x, dim_y = seg_unCrop.dim[0], seg_unCrop.dim[1]

    for i_z, zz in enumerate(z_crop_lst):
        x_start = int(x_crop_lst[i_z])
        y_start = int(y_crop_lst[i_z])
        # clip the paste window to the image bounds
        x_end = min(x_start + crop_x, dim_x)
        y_end = min(y_start + crop_y, dim_y)
        seg_unCrop.data[x_start:x_end, y_start:y_end, zz] = data_crop[:x_end - x_start, :y_end - y_start, zz]

    return seg_unCrop
def deep_segmentation_spinalcord(im_image, contrast_type, ctr_algo='cnn', ctr_file=None, brain_bool=True,
                                 kernel_size='2d', threshold_seg=None, remove_temp_files=1, verbose=1):
    """
    Main pipeline for CNN-based segmentation of the spinal cord.

    Steps: reorient to RPI, resample to 0.5mm in-plane, detect the centerline, crop around it,
    normalize intensities, run the 2D or 3D CNN, post-process slice-wise and volume-wise,
    un-crop and resample back to the native space/orientation.

    :param im_image: Image: input image to segment
    :param contrast_type: {'t1', 't2', t2s', 'dwi'}
    :param ctr_algo: str: centerline detection method passed to :func:`find_centerline`
    :param ctr_file: str: path to a centerline file (only used when ctr_algo == 'file')
    :param brain_bool: bool: True when the image may contain the brain
    :param kernel_size: str: {'2d', '3d'}: dimensionality of the CNN
    :param threshold_seg: Binarization threshold (between 0 and 1) to apply to the segmentation prediction. Set to -1
    for no binarization (i.e. soft segmentation output). If None, the contrast-specific default from
    THR_DEEPSEG is used.
    :param remove_temp_files: int: if non-zero, the temporary folder is deleted at the end
    :param verbose: int: verbosity, forwarded to TempFolder
    :return: (im_seg_r_postproc, im_image_res, im_seg): segmentation resampled back to native space,
             the resampled input image, and the segmentation in the resampled RPI space
    """
    if threshold_seg is None:
        threshold_seg = THR_DEEPSEG[contrast_type]

    # Display stuff
    logger.info("Config deepseg_sc:")
    logger.info(" Centerline algorithm: {}".format(ctr_algo))
    logger.info(" Brain in image: {}".format(brain_bool))
    logger.info(" Kernel dimension: {}".format(kernel_size))
    logger.info(" Contrast: {}".format(contrast_type))
    logger.info(" Threshold: {}".format(threshold_seg))

    # create temporary folder with intermediate results
    tmp_folder = TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()

    # re-orient image to RPI
    logger.info("Reorient the image to RPI, if necessary...")
    original_orientation = im_image.orientation
    # fname_orient = 'image_in_RPI.nii'
    im_image.change_orientation('RPI')

    # Resample image to 0.5mm in plane
    # dim[6] is the native z voxel size, which is left unchanged here
    im_image_res = \
        resampling.resample_nib(im_image, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm', interpolation='linear')

    fname_orient = 'image_in_RPI_res.nii'
    im_image_res.save(fname_orient)

    # find the spinal cord centerline - execute OptiC binary
    logger.info("Finding the spinal cord centerline...")
    _, im_ctl, im_labels_viewer = find_centerline(algo=ctr_algo,
                                                  image_fname=fname_orient,
                                                  contrast_type=contrast_type,
                                                  brain_bool=brain_bool,
                                                  folder_output=tmp_folder_path,
                                                  remove_temp_files=remove_temp_files,
                                                  centerline_fname=file_ctr)

    if ctr_algo == 'file':
        # a user-provided centerline is in native space: resample it to match im_image_res
        im_ctl = \
            resampling.resample_nib(im_ctl, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm', interpolation='linear')

    # crop image around the spinal cord centerline
    logger.info("Cropping the image around the spinal cord...")
    # larger crop for the t2s 3D model, whose patch size is 96x96 (see segment_3d)
    crop_size = 96 if (kernel_size == '3d' and contrast_type == 't2s') else 64
    X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(im_in=im_image_res,
                                                                                   ctr_in=im_ctl,
                                                                                   crop_size=crop_size)

    # normalize the intensity of the images
    logger.info("Normalizing the intensity...")
    im_norm_in = apply_intensity_normalization(im_in=im_crop_nii)
    del im_crop_nii

    if kernel_size == '2d':
        # segment data using 2D convolutions
        logger.info("Segmenting the spinal cord using deep learning on 2D patches...")
        segmentation_model_fname = \
            sct_dir_local_path('data', 'deepseg_sc_models', '{}_sc.h5'.format(contrast_type))
        seg_crop = segment_2d(model_fname=segmentation_model_fname,
                              contrast_type=contrast_type,
                              input_size=(crop_size, crop_size),
                              im_in=im_norm_in)
    elif kernel_size == '3d':
        # segment data using 3D convolutions
        logger.info("Segmenting the spinal cord using deep learning on 3D patches...")
        segmentation_model_fname = \
            sct_dir_local_path('data', 'deepseg_sc_models', '{}_sc_3D.h5'.format(contrast_type))
        seg_crop = segment_3d(model_fname=segmentation_model_fname,
                              contrast_type=contrast_type,
                              im_in=im_norm_in)

    # Postprocessing
    seg_crop_postproc = np.zeros_like(seg_crop)
    x_cOm, y_cOm = None, None
    for zz in range(im_norm_in.dim[2]):
        # Fill holes (only for binary segmentations)
        if threshold_seg >= 0:
            pred_seg_th = fill_holes_2d((seg_crop[:, :, zz] > threshold_seg).astype(int))
            pred_seg_pp = keep_largest_object(pred_seg_th, x_cOm, y_cOm)
            # Update center of mass for slice i+1
            if 1 in pred_seg_pp:
                x_cOm, y_cOm = center_of_mass(pred_seg_pp)
                x_cOm, y_cOm = np.round(x_cOm), np.round(y_cOm)
        else:
            # If soft segmentation, do nothing
            pred_seg_pp = seg_crop[:, :, zz]

        seg_crop_postproc[:, :, zz] = pred_seg_pp  # dtype is float32

    # reconstruct the segmentation from the crop data
    logger.info("Reassembling the image...")
    im_seg = uncrop_image(ref_in=im_image_res,
                          data_crop=seg_crop_postproc,
                          x_crop_lst=X_CROP_LST,
                          y_crop_lst=Y_CROP_LST,
                          z_crop_lst=Z_CROP_LST)
    # seg_uncrop_nii.save(add_suffix(fname_res, '_seg'))  # for debugging
    del seg_crop, seg_crop_postproc, im_norm_in

    # resample to initial resolution
    logger.info("Resampling the segmentation to the native image resolution using linear interpolation...")
    im_seg_r = resampling.resample_nib(im_seg, image_dest=im_image, interpolation='linear')

    if ctr_algo == 'viewer':  # for debugging
        im_labels_viewer.save(add_suffix(fname_orient, '_labels-viewer'))

    # Binarize the resampled image (except for soft segmentation, defined by threshold_seg=-1)
    if threshold_seg >= 0:
        logger.info("Binarizing the resampled segmentation...")
        im_seg_r.data = (im_seg_r.data > 0.5).astype(np.uint8)

    # post processing step to z_regularized
    im_seg_r_postproc = post_processing_volume_wise(im_seg_r)

    # Change data type. By default, dtype is float32
    if threshold_seg >= 0:
        im_seg_r_postproc.change_type(np.uint8)

    tmp_folder.chdir_undo()

    # remove temporary files
    if remove_temp_files:
        logger.info("Remove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    im_seg_r_postproc.change_orientation(original_orientation)

    # copy q/sform from input image to output segmentation
    # NOTE(review): this copies to `im_seg` (the RPI-space segmentation returned third), not to
    # `im_seg_r_postproc` (the native-space output) -- confirm this is intentional
    im_seg.copy_qform_from_ref(im_image)

    return im_seg_r_postproc, im_image_res, im_seg.change_orientation('RPI')