# coding: utf-8
# This is the interface API for the deepseg_gm model
# that implements the model for the Spinal Cord Gray Matter Segmentation.
#
# Reference paper:
# Perone, C. S., Calabrese, E., & Cohen-Adad, J. (2017).
# Spinal cord gray matter segmentation using deep dilated convolutions.
# URL: https://arxiv.org/abs/1710.01269
import warnings
import json
import os
import sys
import io
import nibabel as nib
import numpy as np
# Avoid Keras logging: whatever Keras writes to stderr while it is being
# imported is sent to a throw-away in-memory buffer instead of the console.
original_stderr = sys.stderr
if sys.hexversion < 0x03000000:
    # Python 2: stderr is a byte stream.
    sys.stderr = io.BytesIO()
else:
    # Python 3: stderr is a text stream. Reuse the real stream's encoding
    # when it has one (sys.stderr can be None, e.g. without a console).
    sys.stderr = io.TextIOWrapper(
        io.BytesIO(), getattr(original_stderr, 'encoding', None))
try:
    from keras import backend as K
finally:
    # Restore the real stderr on every exit path (success, ImportError,
    # KeyboardInterrupt, ...); any exception still propagates unchanged.
    sys.stderr = original_stderr
from spinalcordtoolbox import resampling, __data_dir__
from . import model
# Hide FutureWarning messages and turn down TensorFlow logging
warnings.simplefilter('ignore', category=FutureWarning)
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '3'})

# In-plane size (pixels) used both as the "small input" limit and as the
# side of the square center crop fed to the network.
SMALL_INPUT_SIZE = 200

# Batch size handed to the model's predict() calls.
BATCH_SIZE = 4
def check_backend():
    """Report the active Keras backend, warning when it is not Tensorflow.

    A warning is printed to stdout for any backend whose name differs
    from ``'tensorflow'``.

    :return: the backend name as reported by ``K.backend()``.
    """
    backend_name = K.backend()
    if backend_name != 'tensorflow':
        print("\nWARNING: you're using a Keras backend different than\n"
              "Tensorflow, which is not recommended. Please verify\n"
              "your configuration file according to: https://keras.io/backend/\n"
              "to make sure you're using Tensorflow Keras backend.\n")
    return backend_name
class DataResource(object):
    """Locate resource files (such as model weights and metadata)
    stored below a sub-directory of the package data directory."""

    def __init__(self, dirname):
        """Bind the resource to a sub-directory of ``__data_dir__``.

        :param dirname: name of the directory, relative to ``__data_dir__``.
        """
        self.data_root = os.path.join(__data_dir__, dirname)

    def get_file_path(self, filename):
        """Build the path of a file located in the resource directory.

        :param filename: the filename, relative to the resource directory.
        :return: ``filename`` joined onto the resource root directory.
        """
        root = self.data_root
        return os.path.join(root, filename)
class CroppedRegion(object):
    """Remember where a center crop was taken from a 2D slice so that
    a result computed on the crop can be padded back to the full size.
    """

    def __init__(self, original_shape, starts, crops):
        """Store the geometry of the crop.

        :param original_shape: the original slice shape, as (x, y).
        :param starts: where the crop begins, as (x, y).
        :param crops: the crop extent, as (x, y).
        """
        self.originalx, self.originaly = original_shape[0], original_shape[1]
        self.startx, self.starty = starts[0], starts[1]
        self.cropx, self.cropy = crops[0], crops[1]

    def pad(self, image):
        """Zero-pad a cropped image back to the original slice size.

        :param image: 2D image to pad; axis 0 is y and axis 1 is x.
        :return: the padded image.
        """
        # (before, after) padding for each axis, rows (y) first.
        rows = (self.starty,
                self.originaly - (self.starty + self.cropy))
        cols = (self.startx,
                self.originalx - (self.startx + self.cropx))
        return np.pad(image, (rows, cols), mode="constant")
def crop_center(img, cropx, cropy):
    """Extract the central window of a 2D image.

    :param img: 2D image to be cropped; axis 0 is y and axis 1 is x.
    :param cropx: width of the crop along x.
    :param cropy: height of the crop along y.
    :return: (cropped image, cropped region)
    :raises RuntimeError: if the computed crop origin is negative, i.e.
                          the image is smaller than the requested crop.
    """
    height, width = img.shape
    left = width // 2 - (cropx // 2)
    top = height // 2 - (cropy // 2)

    if min(left, top) < 0:
        raise RuntimeError("Negative crop.")

    region = CroppedRegion((width, height), (left, top), (cropx, cropy))
    window = img[top:top + cropy, left:left + cropx]
    return window, region
def threshold_predictions(predictions, thr=0.999):
    """Binarize predictions against a threshold.

    Values ``<= thr`` become 0 and values ``> thr`` become 1.

    NOTE: ``predictions[:]`` on a numpy array is a view, not a copy, so
    the input array itself is modified in place and the returned array
    shares its memory. This behaviour is kept for backward compatibility.

    :param predictions: numpy array holding the predictions.
    :param thr: the threshold (if None, no threshold will
                be applied).
    :return: thresholded predictions
    """
    thresholded_preds = predictions[:]
    if thr is None:
        return thresholded_preds

    # Evaluate both masks on the untouched values. Computing the second
    # mask after zeroing the low values would turn those zeros into ones
    # whenever thr is negative (0 > thr).
    low_values_indices = thresholded_preds <= thr
    high_values_indices = thresholded_preds > thr
    thresholded_preds[low_values_indices] = 0
    thresholded_preds[high_values_indices] = 1
    return thresholded_preds
def segment_volume(ninput_volume, model_name,
                   threshold=0.999, use_tta=False):
    """Segment a nifti volume.

    The volume is processed slice by slice along its third axis. When
    either in-plane dimension is ``<= SMALL_INPUT_SIZE`` the slices are
    fed to the network uncropped; otherwise every slice is center-cropped
    to ``SMALL_INPUT_SIZE x SMALL_INPUT_SIZE`` and the prediction is
    zero-padded back to the original in-plane size.

    :param ninput_volume: the input volume (nibabel image).
    :param model_name: the name of the model to use (key of
                       ``model.MODELS``).
    :param threshold: threshold to be applied in predictions (if None,
                      no threshold is applied and the raw predictions
                      are returned).
    :param use_tta: whether TTA (test-time augmentation)
                    should be used or not.
    :return: segmented slices, stacked along the last axis. The array is
             ``uint8`` when a threshold is applied; with ``threshold=None``
             the dtype of the predictions is kept.
    """
    gmseg_model_challenge = DataResource('deepseg_gm_models')
    model_path, metadata_path = model.MODELS[model_name]

    metadata_abs_path = gmseg_model_challenge.get_file_path(metadata_path)
    with open(metadata_abs_path) as fp:
        metadata = json.load(fp)

    volume_size = np.array(ninput_volume.shape[0:2])
    small_input = (volume_size <= SMALL_INPUT_SIZE).any()

    if small_input:
        # Smaller than the trained net, don't crop
        net_input_size = volume_size
    else:
        # Larger size, crop at SMALL_INPUT_SIZE x SMALL_INPUT_SIZE
        net_input_size = (SMALL_INPUT_SIZE, SMALL_INPUT_SIZE)

    deepgmseg_model = model.create_model(metadata['filters'],
                                         net_input_size)

    model_abs_path = gmseg_model_challenge.get_file_path(model_path)
    deepgmseg_model.load_weights(model_abs_path)

    # Same array that the deprecated ``get_data()`` returned.
    volume_data = np.asanyarray(ninput_volume.dataobj)

    axial_slices = []
    crops = []
    for slice_num in range(volume_data.shape[2]):
        data = volume_data[..., slice_num]
        if not small_input:
            data, cropreg = crop_center(data, SMALL_INPUT_SIZE,
                                        SMALL_INPUT_SIZE)
            crops.append(cropreg)
        axial_slices.append(data)

    axial_slices = np.asarray(axial_slices, dtype=np.float32)
    # Add a trailing singleton axis: (slices, rows, cols, 1).
    axial_slices = np.expand_dims(axial_slices, axis=3)

    # NOTE(review): VolumeStandardizationTransform is not defined in the
    # visible part of this module -- confirm where it comes from.
    normalization = VolumeStandardizationTransform()
    axial_slices = normalization(axial_slices)

    if use_tta:
        # Average 8 predictions made on intensity-shifted copies of the
        # input (offset drawn uniformly from [0, 2)) together with one
        # prediction on the unshifted input.
        pred_sampled = []
        for _ in range(8):
            sampled_value = np.random.uniform(high=2.0)
            sampled_axial_slices = axial_slices + sampled_value
            preds = deepgmseg_model.predict(sampled_axial_slices,
                                            batch_size=BATCH_SIZE,
                                            verbose=True)
            pred_sampled.append(preds)

        preds = deepgmseg_model.predict(axial_slices, batch_size=BATCH_SIZE,
                                        verbose=True)
        pred_sampled.append(preds)
        pred_sampled = np.asarray(pred_sampled)
        pred_sampled = np.mean(pred_sampled, axis=0)
        preds = threshold_predictions(pred_sampled, threshold)
    else:
        preds = deepgmseg_model.predict(axial_slices, batch_size=BATCH_SIZE,
                                        verbose=True)
        preds = threshold_predictions(preds, threshold)

    pred_slices = []

    # Un-cropping
    for slice_num, pred in enumerate(preds):
        pred_slice = pred[..., 0]
        if not small_input:
            pred_slice = crops[slice_num].pad(pred_slice)
        pred_slices.append(pred_slice)

    # Only a thresholded (0/1) result can be stored as uint8; casting
    # unthresholded predictions to uint8 would truncate them to 0.
    out_dtype = np.uint8 if threshold is not None else None
    pred_slices = np.asarray(pred_slices, dtype=out_dtype)
    # (slices, rows, cols) -> (rows, cols, slices)
    pred_slices = np.transpose(pred_slices, (1, 2, 0))

    return pred_slices
def segment_file(input_filename, output_filename,
                 model_name, threshold, verbosity,
                 use_tta):
    """Segment a volume file.

    The input is resampled to 0.25 x 0.25 mm in-plane (the resolution
    along the third axis is kept), segmented with ``segment_volume``,
    resampled back to the original resolution and written to
    ``output_filename``.

    :param input_filename: the input filename.
    :param output_filename: the output filename.
    :param model_name: the name of model to use.
    :param threshold: threshold to apply in predictions (if None,
                      no threshold will be applied)
    :param verbosity: the verbosity level (currently not used by this
                      function).
    :param use_tta: whether it should use TTA (test-time augmentation)
                    or not.
    :return: the output filename.
    """
    nii_original = nib.load(input_filename)
    pixdim = nii_original.header["pixdim"]
    original_res = [pixdim[1], pixdim[2], pixdim[3]]
    target_resample = [0.25, 0.25, pixdim[3]]

    nii_resampled = resampling.resample_nib(
        nii_original, new_size=target_resample, new_size_type='mm',
        interpolation='linear')
    pred_slices = segment_volume(nii_resampled, model_name, threshold,
                                 use_tta)

    volume_affine = nii_resampled.affine
    volume_header = nii_resampled.header
    nii_segmentation = nib.Nifti1Image(pred_slices, volume_affine,
                                       volume_header)

    nii_resampled_original = resampling.resample_nib(
        nii_segmentation, new_size=original_res, new_size_type='mm',
        interpolation='linear')

    # Threshold after resampling, only if specified
    if threshold is not None:
        res_data = threshold_predictions(
            np.asanyarray(nii_resampled_original.dataobj), 0.5)
        # Attach the thresholded array to the image that gets saved,
        # instead of relying on the threshold having modified the image's
        # internal array in place.
        nii_resampled_original = nib.Nifti1Image(
            res_data,
            nii_resampled_original.affine,
            nii_resampled_original.header)

    nib.save(nii_resampled_original, output_filename)
    return output_filename