Source code for spinalcordtoolbox.deepseg_gm.model

# coding: utf-8
# This is the deepseg_gm model definition for the
# Spinal Cord Gray Matter Segmentation.
#
# Reference paper:
#     Perone, C. S., Calabrese, E., & Cohen-Adad, J. (2017).
#     Spinal cord gray matter segmentation using deep dilated convolutions.
#     URL: https://arxiv.org/abs/1710.01269

import keras.backend as K
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D, Dropout
from keras.layers import RepeatVector, Reshape
from keras.layers import BatchNormalization
from keras.layers import concatenate, GlobalAveragePooling2D
from keras.optimizers import Adam

# Registry of the available models.
# Each entry maps a model name to a tuple of two file names:
# (model file, metadata file).
MODELS = {
    'challenge': (
        'challenge_model.hdf5',
        'challenge_model.json',
    ),
    'large': (
        'large_model.hdf5',
        'large_model.json',
    ),
}


def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient between ground truth and predictions.

    Both tensors are flattened first, so the score is computed over every
    element of the tensors at once. A smoothing factor of 1.0 is added to
    both the numerator and the denominator.

    :param y_true: ground truth.
    :param y_pred: predictions.
    :return: ``(2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1)`` where ``t``
             and ``p`` are the flattened ground truth and predictions.
    """
    smooth = 1.0

    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)

    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smooth
    return numerator / denominator
def dice_coef_loss(y_true, y_pred):
    """Dice loss: the negated Dice coefficient.

    Minimizing this value maximizes :func:`dice_coef`.

    :param y_true: ground truth.
    :param y_pred: predictions.
    :return: ``-dice_coef(y_true, y_pred)``.
    """
    score = dice_coef(y_true, y_pred)
    return -score
def _conv_block(tensor, filters, kernel_size, bn_momentum, drop_rate,
                dilation_rate=(1, 1), name=None):
    """Apply Conv2D (relu, 'same' padding) -> BatchNormalization -> Dropout.

    :param tensor: input tensor of the block.
    :param filters: number of filters of the convolution.
    :param kernel_size: kernel size of the convolution, as (H, W).
    :param bn_momentum: momentum given to the batch normalization layer.
    :param drop_rate: rate given to the dropout layer.
    :param dilation_rate: dilation rate of the convolution, as (H, W).
    :param name: optional name of the convolution layer.
    :return: output tensor of the dropout layer.
    """
    tensor = Conv2D(filters, kernel_size, dilation_rate=dilation_rate,
                    activation='relu', padding='same', name=name)(tensor)
    tensor = BatchNormalization(momentum=bn_momentum)(tensor)
    return Dropout(drop_rate)(tensor)


def create_model(nfilters, input_size=(200, 200)):
    """Create the ASPP model.

    The network is made of two plain 3x3 convolution blocks, two 3x3
    blocks with dilation rate 2, and then six parallel branches that are
    concatenated along the channel axis: a branch ending in a 1x1
    convolution, four dilated branches (rates 6, 12, 18 and 24) and a
    global average pooling branch. A 1x1 convolution with 64 filters
    ("amort") and a 1x1 sigmoid convolution ("predictions") follow.

    Layers are instantiated in a fixed order and the named layers keep
    fixed names; do not reorder them without checking that previously
    saved weights still load (presumably the files listed in ``MODELS``
    depend on this -- confirm before changing).

    :param nfilters: number of filters at each block (the "amort" and
                     "predictions" convolutions use 64 and 1 filters).
    :param input_size: the network input size (H, W)
    :return: the model, compiled with the Adam optimizer (lr=0.001),
             the Dice loss and the "accuracy" metric.
    """
    drop_rate_concat = 0.4
    drop_rate_hidden = 0.4
    bn_momentum = 0.1

    input_height, input_width = input_size

    def hidden_block(tensor, kernel_size=(3, 3), rate=1, name=None):
        # Conv -> BN -> Dropout block with the shared hidden settings.
        return _conv_block(tensor, nfilters, kernel_size,
                           bn_momentum, drop_rate_hidden,
                           dilation_rate=(rate, rate), name=name)

    def dilated_pair(tensor, rate):
        # Two stacked 3x3 blocks sharing one dilation rate, with the
        # convolutions named "rate<rate>_1" and "rate<rate>_2".
        tensor = hidden_block(tensor, rate=rate, name="rate%d_1" % rate)
        return hidden_block(tensor, rate=rate, name="rate%d_2" % rate)

    inputs = Input((input_height, input_width, 1))

    # Two plain (non-dilated, unnamed) 3x3 blocks
    stem = hidden_block(inputs)
    stem = hidden_block(stem)

    # Rate 2
    trunk = dilated_pair(stem, 2)

    # Branches for ASPP, all fed by the rate 2 output

    # Branch for 1x1
    # NOTE(review): the first convolution of this branch is 3x3, only the
    # second one is 1x1. Kept exactly as it was.
    branch_1x1 = hidden_block(trunk, name="branch1x1_1")
    branch_1x1 = hidden_block(branch_1x1, kernel_size=(1, 1),
                              name="branch1x1_2")

    # Branches for 3x3 rate 6, 12, 18 and 24
    branch_rate6 = dilated_pair(trunk, 6)
    branch_rate12 = dilated_pair(trunk, 12)
    branch_rate18 = dilated_pair(trunk, 18)
    branch_rate24 = dilated_pair(trunk, 24)

    # Branch for the global context. It pools the output of the two plain
    # blocks (not the rate 2 output), then repeats the pooled vector
    # H * W times and reshapes it to (H, W, nfilters) so that it can be
    # concatenated with the other branches.
    global_pool = GlobalAveragePooling2D()(stem)
    global_pool = RepeatVector(input_height * input_width)(global_pool)
    global_pool = Reshape((input_height, input_width, nfilters))(global_pool)

    # Concatenation (the global context sits between rate 18 and rate 24;
    # this order fixes the channel layout seen by the "amort" layer)
    concat = concatenate([branch_1x1, branch_rate6, branch_rate12,
                          branch_rate18, global_pool, branch_rate24],
                         axis=3)
    concat = BatchNormalization(momentum=bn_momentum)(concat)
    concat = Dropout(drop_rate_concat)(concat)

    amort = _conv_block(concat, 64, (1, 1), bn_momentum, drop_rate_hidden,
                        name="amort")

    predictions = Conv2D(1, (1, 1), activation='sigmoid',
                         padding='same', name="predictions")(amort)

    model = Model(inputs=[inputs], outputs=[predictions])

    opt = Adam(lr=0.001)
    model.compile(optimizer=opt, loss=dice_coef_loss, metrics=["accuracy"])
    return model