#!/usr/bin/env python
# -*- coding: utf-8
# Functions dealing with deepseg_lesion
import os
import logging
import numpy as np
from scipy.interpolate.interpolate import interp1d
import nibabel as nib
from spinalcordtoolbox.image import Image, add_suffix, zeros_like, empty_like
from spinalcordtoolbox.deepseg_sc.core import find_centerline, crop_image_around_centerline, uncrop_image, _normalize_data
from spinalcordtoolbox import resampling
from spinalcordtoolbox.utils import sct_dir_local_path, TempFolder
from spinalcordtoolbox.deepseg_sc.cnn_models_3d import load_trained_model
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Passed as `batch_size` to the model's `predict` call in segment_3d.
BATCH_SIZE = 4
# Contrast identifiers; they match the keys of the per-contrast tables used in
# apply_intensity_normalization and segment_3d.
MODEL_LST = ['t2', 't2_ax', 't2s']
def apply_intensity_normalization_model(img, landmarks_lst):
    """Map image intensities onto the learned standard intensity space.

    The 1st, 10th, ..., 90th and 99th percentiles of ``img`` are computed and
    mapped piecewise-linearly onto ``landmarks_lst``. Intensities outside the
    [1st, 99th] percentile range are mapped with exponential models built by
    ``exp_model`` instead of the linear interpolation.

    :param img: numpy array of intensities to normalize.
    :param landmarks_lst: learned landmark intensities, one per percentile
        listed in ``percent_decile_lst`` (11 values).
    :return: float32 numpy array of normalized intensities, same shape as ``img``.
    """
    percent_decile_lst = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99]

    # np.percentile works on the flattened array, so the image is passed as-is:
    # wrapping it in list() only forced NumPy to re-stack a copy of the data.
    landmarks_lst_cur = np.percentile(img, q=percent_decile_lst)

    # create linear mapping models for the percentile segments to the learned standard intensity space
    # (bounds_error=False: values outside the landmark range become NaN and are overwritten below)
    linear_mapping = interp1d(landmarks_lst_cur, landmarks_lst, bounds_error=False)

    # transform the input image intensity values
    output = linear_mapping(img)

    # treat image intensity values outside of the cut-off percentiles range separately
    below_mask = img < landmarks_lst_cur[0]
    above_mask = img > landmarks_lst_cur[-1]

    # lower tail: model built from the first two landmarks, with asymptote at the first learned landmark
    below_mapping = exp_model(landmarks_lst_cur[:2], landmarks_lst[:2], landmarks_lst[0])
    output[below_mask] = below_mapping(img[below_mask])

    # upper tail: model built from the 80th/90th percentile landmarks ([-3:-1]),
    # with asymptote at the last learned landmark
    above_mapping = exp_model(landmarks_lst_cur[-3:-1], landmarks_lst[-3:-1], landmarks_lst[-1])
    output[above_mask] = above_mapping(img[above_mask])

    return output.astype(np.float32)
def exp_model(xs, ys, s2):
    """Build an exponential mapping ``x -> alpha + beta * exp(gamma * x)``.

    The curve is derived from the straight line through ``(xs[0], ys[0])`` and
    ``(xs[1], ys[1])``: at ``xs[1]`` it takes the value ``ys[1]`` and has the
    same slope as that line, and its additive offset ``alpha`` is ``s2``, the
    value the curve tends to as the exponential term decays.

    :param xs: pair of abscissas ``(x1, x2)``.
    :param ys: pair of ordinates ``(y1, y2)``.
    :param s2: additive offset (asymptote) of the exponential.
    :return: callable applying the model to a scalar or numpy array.
    """
    (x_first, x_second), (y_first, y_second) = xs, ys

    # straight line through the two points: y = slope * x + intercept
    slope = (y_second - y_first) / (x_second - x_first)
    intercept = y_first - (slope * x_first)

    # the exponential is anchored at the second abscissa
    anchor = x_second

    # y = s2 + beta * exp(gamma * x)
    omega = slope * anchor - s2 + intercept
    beta = omega * np.exp(-slope * anchor * 1.0 / omega)
    gamma = slope * 1.0 / omega

    def model(x):
        return s2 + beta * np.exp(gamma * x)

    return model
def apply_intensity_normalization(img, contrast):
    """Standardize the intensity range of an image for a given contrast.

    :param img: image object whose ``data`` array is normalized.
    :param contrast: key selecting the learned landmarks ('t2', 't2_ax' or 't2s').
    :return: new image (created with ``empty_like``) holding the normalized data.
    """
    source_data = img.data.astype(np.float32)

    # Learned intensity landmarks per contrast (11 values each, consumed by
    # apply_intensity_normalization_model).
    landmarks_by_contrast = {
        't2': [0.000000, 136.832187, 312.158435, 448.968030, 568.657779, 696.671586,
               859.221138, 1074.463414, 1373.289174, 1811.522669, 2611.000000],
        't2_ax': [0.000000, 112.195357, 291.611185, 446.727066, 581.103970, 702.979079,
                  833.318257, 1011.856313, 1268.801813, 1687.137075, 2611.000000],
        't2s': [0.000000, 123.246969, 226.422561, 338.361023, 532.341924, 788.693675,
                1096.975553, 1407.979466, 1716.524530, 2079.788451, 2611.000000],
    }

    normalized = empty_like(img)
    normalized.data = apply_intensity_normalization_model(source_data, landmarks_by_contrast[contrast])
    return normalized
def segment_3d(model_fname, contrast_type, im):
    """Segment an image with a 3D CNN, one patch at a time along the third axis.

    The volume is split along its third axis into consecutive patches whose
    depth is the patch size of the selected contrast. The last patch is
    zero-padded to the full patch size. Patches that contain only zeros are
    skipped; the others are normalized, run through the model, and thresholded
    at 0.1.

    :param model_fname: path of the trained model, given to ``load_trained_model``.
    :param contrast_type: key selecting patch size and normalization statistics
        ('t2', 't2_ax' or 't2s').
    :param im: image object whose ``data`` array is segmented.
    :return: copy of an image built from ``im`` with ``zeros_like`` whose data
        is the thresholded prediction volume.
    """
    dct_patch_3d = {
        't2': {'size': (48, 48, 48), 'mean': 871.309, 'std': 557.916},
        't2_ax': {'size': (48, 48, 48), 'mean': 835.592, 'std': 528.386},
        't2s': {'size': (48, 48, 48), 'mean': 1011.31, 'std': 678.985},
    }

    # load 3d model
    seg_model = load_trained_model(model_fname)

    prediction = np.zeros(im.data.shape)

    patch_cfg = dct_patch_3d[contrast_type]
    patch_depth = patch_cfg['size'][2]
    n_slices = im.data.shape[2]
    z_starts = list(range(0, n_slices, patch_depth))

    for z_start in z_starts:
        is_last_patch = z_start == z_starts[-1]

        if is_last_patch:
            # the last patch may be shorter than the patch depth: pad it with zeros
            n_valid = n_slices - z_start
            patch = np.zeros(patch_cfg['size'])
            patch[:, :, :n_valid] = im.data[:, :, z_start:]
        else:
            patch = im.data[:, :, z_start:z_start + patch_depth]

        # Skip empty patches, which could occur after a brain detection.
        if not np.any(patch):
            continue

        patch_norm = _normalize_data(patch, patch_cfg['mean'], patch_cfg['std'])
        # add two leading singleton axes before prediction
        model_input = np.expand_dims(np.expand_dims(patch_norm, 0), 0)
        proba = seg_model.predict(model_input, batch_size=BATCH_SIZE)
        binary = (proba > 0.1).astype(int)[0, 0, :, :, :]

        if is_last_patch:
            # drop the zero-padded slices
            prediction[:, :, z_start:] = binary[:, :, :n_valid]
        else:
            prediction[:, :, z_start:z_start + patch_depth] = binary

    out = zeros_like(im, dtype=np.uint8)
    out.data = prediction
    return out.copy()
def deep_segmentation_MSlesion(im_image, contrast_type, ctr_algo='svm', ctr_file=None, brain_bool=True,
                               remove_temp_files=1, verbose=1):
    """
    Segment lesions from MRI data.

    The image is reoriented to RPI, resampled to 0.5mm in-plane, cropped around the spinal cord centerline,
    intensity-normalized, resampled to 0.5mm isotropic and segmented with a 3D CNN. The segmentation is then
    resampled back to the input resolution, binarized and reoriented to the input orientation.

    The working directory is switched to a temporary folder while processing and is always restored, even when
    a step raises.

    :param im_image: Image() object containing the lesions to segment.
        NOTE(review): ``change_orientation('RPI')`` is called on it and the result is not reassigned, so the
        caller's object is presumably reoriented in place -- confirm.
    :param contrast_type: Contrast of the image. Need to use one supported by the CNN models.
    :param ctr_algo: Algo to find the centerline. See sct_get_centerline
    :param ctr_file: Centerline or segmentation (optional), used when ctr_algo is 'file'.
    :param brain_bool: If brain is present or not in the image.
    :param remove_temp_files: If truthy, the temporary folder is cleaned up before returning. Also forwarded to
        find_centerline.
    :param verbose: Forwarded to TempFolder. When equal to 2, the centerline image is also resampled and returned.
    :return: tuple (segmentation in the original orientation,
        viewer labels image or None (only set when ctr_algo is 'viewer'),
        resampled centerline image or None (only set when verbose is 2))
    """
    # create temporary folder with intermediate results
    tmp_folder = TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()

    # Everything below runs inside the temporary folder; the finally clause guarantees the caller's working
    # directory is restored if any step fails.
    try:
        fname_in = im_image.absolutepath

        # re-orient image to RPI
        logger.info("Reorient the image to RPI, if necessary...")
        original_orientation = im_image.orientation
        im_image.change_orientation('RPI')
        # dim[4:7] is used below as the voxel size in mm
        input_resolution = im_image.dim[4:7]

        # Resample image to 0.5mm in plane
        im_image_res = \
            resampling.resample_nib(im_image, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm',
                                    interpolation='linear')

        fname_orient = 'image_in_RPI_res.nii'
        im_image_res.save(fname_orient)

        # find the spinal cord centerline - execute OptiC binary
        logger.info("\nFinding the spinal cord centerline...")
        # the centerline detection only takes the base contrast (e.g. 't2_ax' -> 't2')
        contrast_type_ctr = contrast_type.split('_')[0]
        _, im_ctl, im_labels_viewer = find_centerline(algo=ctr_algo,
                                                      image_fname=fname_orient,
                                                      contrast_type=contrast_type_ctr,
                                                      brain_bool=brain_bool,
                                                      folder_output=tmp_folder_path,
                                                      remove_temp_files=remove_temp_files,
                                                      centerline_fname=file_ctr)
        if ctr_algo == 'file':
            # a user-provided centerline is brought to the same 0.5mm in-plane grid as the image
            im_ctl = \
                resampling.resample_nib(im_ctl, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm',
                                        interpolation='linear')

        # crop image around the spinal cord centerline
        logger.info("\nCropping the image around the spinal cord...")
        crop_size = 48
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(im_in=im_image_res,
                                                                                      ctr_in=im_ctl,
                                                                                      crop_size=crop_size)
        del im_ctl

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(img=im_crop_nii, contrast=contrast_type)
        del im_crop_nii

        # resample to 0.5mm isotropic
        fname_norm = add_suffix(fname_orient, '_norm')
        im_norm_in.save(fname_norm)
        fname_res3d = add_suffix(fname_norm, '_resampled3d')
        resampling.resample_file(fname_norm, fname_res3d, '0.5x0.5x0.5', 'mm', 'linear',
                                 verbose=0)

        # segment data using 3D convolutions
        logger.info("\nSegmenting the MS lesions using deep learning on 3D patches...")
        segmentation_model_fname = sct_dir_local_path('data', 'deepseg_lesion_models',
                                                      '{}_lesion.h5'.format(contrast_type))
        fname_seg_crop_res = add_suffix(fname_res3d, '_lesionseg')
        im_res3d = Image(fname_res3d)
        seg_im = segment_3d(model_fname=segmentation_model_fname,
                            contrast_type=contrast_type,
                            im=im_res3d.copy())
        seg_im.save(fname_seg_crop_res)
        del im_res3d, seg_im

        # resample to the initial pz resolution
        fname_seg_res2d = add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resampling.resample_file(fname_seg_crop_res, fname_seg_res2d, initial_2d_resolution,
                                 'mm', 'linear', verbose=0)
        seg_crop = Image(fname_seg_res2d)

        # reconstruct the segmentation from the crop data
        logger.info("\nReassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_image_res, data_crop=seg_crop.copy().data, x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST, z_crop_lst=Z_CROP_LST)
        # NOTE(review): this path and fname_seg_RPI below derive from the input image's absolute path, so these
        # files presumably land next to the input rather than in the temporary folder -- confirm intended.
        fname_seg_res_RPI = add_suffix(fname_in, '_res_RPI_seg')
        seg_uncrop_nii.save(fname_seg_res_RPI)
        del seg_crop

        # resample to initial resolution
        logger.info("Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join([str(input_resolution[0]), str(input_resolution[1]), str(input_resolution[2])])
        fname_seg_RPI = add_suffix(fname_in, '_RPI_seg')
        resampling.resample_file(fname_seg_res_RPI, fname_seg_RPI, initial_resolution,
                                 'mm', 'linear', verbose=0)
        seg_initres_nii = Image(fname_seg_RPI)

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            im_labels_viewer_nib = nib.nifti1.Nifti1Image(im_labels_viewer.data,
                                                          im_labels_viewer.hdr.get_best_affine())
            im_viewer_r_nib = resampling.resample_nib(im_labels_viewer_nib, new_size=input_resolution,
                                                      new_size_type='mm', interpolation='linear')
            # np.asanyarray(img.dataobj) is nibabel's documented replacement for the removed get_data()
            im_viewer = Image(np.asanyarray(im_viewer_r_nib.dataobj), hdr=im_viewer_r_nib.header,
                              orientation='RPI',
                              dim=im_viewer_r_nib.header.get_data_shape()).change_orientation(original_orientation)
        else:
            im_viewer = None

        if verbose == 2:
            # the '_ctr' file is resampled in place (same input and output name)
            fname_res_ctr = add_suffix(fname_orient, '_ctr')
            resampling.resample_file(fname_res_ctr, fname_res_ctr, initial_resolution,
                                     'mm', 'linear', verbose=0)
            im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(original_orientation)
        else:
            im_image_res_ctr_downsamp = None

        # binarize the resampled image to remove interpolation effects
        logger.info("\nBinarizing the segmentation to avoid interpolation effects...")
        thr = 0.1
        seg_initres_nii.data[np.where(seg_initres_nii.data >= thr)] = 1
        seg_initres_nii.data[np.where(seg_initres_nii.data < thr)] = 0

        # change data type
        seg_initres_nii.change_type(np.uint8)

        logger.info("\nReorienting the segmentation to the original image orientation...")
    finally:
        # always leave the temporary folder, whether or not the pipeline succeeded
        tmp_folder.chdir_undo()

    # remove temporary files (only reached on success, so a failed run keeps its intermediate files)
    if remove_temp_files:
        logger.info("\nRemove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return seg_initres_nii.change_orientation(original_orientation), im_viewer, im_image_res_ctr_downsamp