# Source code for sksurgerytf.models.rgb_unet

# -*- coding: utf-8 -*-

"""
Module to implement a semantic (pixelwise) segmentation using UNet on 512x512.
"""

#pylint: disable=line-too-long, too-many-instance-attributes, unsubscriptable-object, too-many-branches, too-many-arguments

import os
import sys
import glob
import logging
import datetime
import platform
import ssl
import shutil
import getpass
from pathlib import Path
import numpy as np
import cv2
from tensorflow import keras
from sksurgerytf import __version__
import sksurgerytf.callbacks.segmentation_history as sh
import sksurgerytf.utils.segmentation_statistics as ss

LOGGER = logging.getLogger(__name__)


class RGBUNet:
    """
    Class to encapsulate RGB UNet semantic (pixelwise) segmentation network.

    Thanks to
    `Zhixuhao <https://github.com/zhixuhao/unet/blob/master/model.py>`_, and
    `ShawDa <https://github.com/ShawDa/unet-rgb/blob/master/unet.py>`_
    for getting me started, and
    `Harshall Lamba
    <https://towardsdatascience.com/understanding-semantic-segmentation-with-unet-6be4f42d4b47>`_,
    for further inspiration.
    """
    def __init__(self,
                 logs="logs/fit",
                 data=None,
                 working=None,
                 omit=None,
                 model=None,
                 learning_rate=0.0001,
                 epochs=50,
                 batch_size=2,
                 input_size=(512, 512, 3),
                 patience=20
                 ):
        """
        Class to run UNet on RGB images.

        If the constructor is called without a previously saved model, the
        data is loaded and a full training cycle is performed.

        If the constructor is called with a previously saved model, the model
        is loaded as is, with no further training. You can then call the test
        method to predict the output on new images.

        :param logs: relative path to folder to write tensorboard log files.
        :param data: root directory of training data.
        :param working: working directory for organising data.
        :param omit: patient identifier to omit, when doing Leave-One-Out.
        :param model: file name of previously saved model.
        :param learning_rate: float, default=0.0001 for Adam optimiser.
        :param epochs: int, default=50.
        :param batch_size: int, default=2.
        :param input_size: Expected input size for network, default (512,512,3).
        :param patience: number of steps to tolerate non-improving accuracy
        """
        LOGGER.info("Creating RGBUNet with log dir: %s.", str(logs))
        LOGGER.info("Creating RGBUNet with data dir: %s.", str(data))
        LOGGER.info("Creating RGBUNet with working dir: %s.", str(working))
        LOGGER.info("Creating RGBUNet with omit: %s.", str(omit))
        LOGGER.info("Creating RGBUNet with model file: %s.", str(model))
        LOGGER.info("Creating RGBUNet with learning_rate: %s.",
                    str(learning_rate))
        LOGGER.info("Creating RGBUNet with epochs: %s.", str(epochs))
        LOGGER.info("Creating RGBUNet with batch_size: %s.", str(batch_size))
        LOGGER.info("Creating RGBUNet with input_size size: %s.",
                    str(input_size))
        LOGGER.info("Creating RGBUNet with patience: %s.", str(patience))

        self.logs = logs
        self.data = data
        self.working = working
        self.omit = omit
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.input_size = input_size
        self.patience = patience

        self.model = None
        self.train_images_working_dir = None
        self.train_masks_working_dir = None
        self.train_generator = None
        self.number_training_samples = None
        self.validate_images_working_dir = None
        self.validate_masks_working_dir = None
        self.validate_generator = None
        self.number_validation_samples = None

        # To fix issues with SSL certificates on CI servers.
        ssl._create_default_https_context = ssl._create_unverified_context

        if model is not None:
            LOGGER.info("Loading Model")
            self.model = keras.models.load_model(model)
            LOGGER.info("Loaded Model")
        else:
            LOGGER.info("Building Model")
            self.model = self._build_model()
            LOGGER.info("Built Model")

        self.model.summary()

        if model is None and self.working is None:
            raise ValueError("You must specify a working (temp) directory")
        if model is None and self.data is None:
            raise ValueError("You must specify the data directory")

        # Only train when data is supplied; a pre-trained model loaded from
        # file is used as-is for prediction.
        if self.data is not None and self.working is not None:
            self._copy_data()
            self._load_data()
            self.train()

    def _copy_images(self, src_dir, dst_dir):
        """
        Symlinks .png files from one directory to another.

        The case directory name is prefixed onto the file name so that files
        from different cases cannot collide in the flat working directory.
        """
        #pylint: disable=no-self-use
        for image_file in glob.iglob(os.path.join(src_dir, "*.png")):
            destination = os.path.join(
                dst_dir,
                os.path.basename(os.path.dirname(src_dir))
                + "_"
                + os.path.basename(image_file))
            os.symlink(image_file, destination)

    def _copy_data(self):
        """
        Copies data from data directory to working directory.

        If the user is doing 'Leave-On-Out' then we validate on that case.

        :raises ValueError: if no case sub-directories exist, or the case
            requested via ``omit`` cannot be found.
        """
        # Look for each case in a sub-directory.
        sub_dirs = [f.path for f in os.scandir(self.data) if f.is_dir()]
        if not sub_dirs:
            raise ValueError("Couldn't find sub directories")
        sub_dirs.sort()

        if self.omit is not None:
            found_it = False
            for directory in sub_dirs:
                if os.path.basename(directory) == self.omit:
                    found_it = True
                    break
            if not found_it:
                raise ValueError("User requested to omit:" + self.omit
                                 + ", but it cannot be found in:" + self.data)

        # Always recreate working directory to avoid data leak.
        if os.path.exists(self.working):
            LOGGER.info("Removing working directory: %s", self.working)
            shutil.rmtree(self.working)

        # Keras still requires a sub-dir for the class name.
        class_name = 'object'

        self.train_images_working_dir = os.path.join(self.working,
                                                     'train',
                                                     'images',
                                                     class_name)
        LOGGER.info("Creating directory: %s", self.train_images_working_dir)
        os.makedirs(self.train_images_working_dir)

        self.train_masks_working_dir = os.path.join(self.working,
                                                    'train',
                                                    'masks',
                                                    class_name)
        LOGGER.info("Creating directory: %s", self.train_masks_working_dir)
        os.makedirs(self.train_masks_working_dir)

        self.validate_images_working_dir = os.path.join(self.working,
                                                        'validate',
                                                        'images',
                                                        class_name)
        LOGGER.info("Creating directory: %s", self.validate_images_working_dir)
        os.makedirs(self.validate_images_working_dir)

        self.validate_masks_working_dir = os.path.join(self.working,
                                                       'validate',
                                                       'masks',
                                                       class_name)
        LOGGER.info("Creating directory: %s", self.validate_masks_working_dir)
        os.makedirs(self.validate_masks_working_dir)

        for sub_dir in sub_dirs:
            images_sub_dir = os.path.join(sub_dir, 'images')
            mask_sub_dir = os.path.join(sub_dir, 'masks')
            # The omitted case becomes the validation set; everything else
            # is training data.
            if self.omit is not None \
                    and self.omit == os.path.basename(sub_dir):
                LOGGER.info("Sym-linking validate images from %s to %s",
                            images_sub_dir, self.validate_images_working_dir)
                self._copy_images(images_sub_dir,
                                  self.validate_images_working_dir)
                LOGGER.info("Sym-linking validate masks from %s to %s",
                            mask_sub_dir, self.validate_masks_working_dir)
                self._copy_images(mask_sub_dir,
                                  self.validate_masks_working_dir)
            else:
                LOGGER.info("Sym-linking train images from %s to %s",
                            images_sub_dir, self.train_images_working_dir)
                self._copy_images(images_sub_dir,
                                  self.train_images_working_dir)
                LOGGER.info("Sym-linking train masks from %s to %s",
                            mask_sub_dir, self.train_masks_working_dir)
                self._copy_images(mask_sub_dir,
                                  self.train_masks_working_dir)

    def _load_data(self):
        """
        Sets up the Keras ImageDataGenerator to load images and masks together.
        """
        train_data_gen_args = dict(rescale=1./255,
                                   horizontal_flip=True,
                                   vertical_flip=False,
                                   fill_mode='nearest',
                                   rotation_range=10,
                                   width_shift_range=0.05,
                                   height_shift_range=0.05,
                                   zoom_range=[0.9, 1.0]
                                   )

        # This is for validation. We don't want data augmentation.
        validate_data_gen_args = dict(rescale=1./255)

        train_image_datagen = keras.preprocessing.image.ImageDataGenerator(
            **train_data_gen_args)
        train_mask_datagen = keras.preprocessing.image.ImageDataGenerator(
            **train_data_gen_args)
        validate_image_datagen = keras.preprocessing.image.ImageDataGenerator(
            **validate_data_gen_args)
        validate_mask_datagen = keras.preprocessing.image.ImageDataGenerator(
            **validate_data_gen_args)

        # Identical seed ensures image and mask generators produce the same
        # random augmentation/ordering, keeping pairs aligned.
        seed = 1

        train_image_generator = train_image_datagen.flow_from_directory(
            os.path.dirname(self.train_images_working_dir),
            target_size=(self.input_size[0], self.input_size[1]),
            batch_size=self.batch_size,
            color_mode='rgb',
            class_mode=None,
            shuffle=True,
            seed=seed)

        train_mask_generator = train_mask_datagen.flow_from_directory(
            os.path.dirname(self.train_masks_working_dir),
            target_size=(self.input_size[0], self.input_size[1]),
            batch_size=self.batch_size,
            color_mode='grayscale',
            class_mode=None,
            shuffle=True,
            seed=seed)

        self.number_training_samples = len(train_image_generator.filepaths)

        self.train_generator = zip(train_image_generator,
                                   train_mask_generator)

        if self.omit is not None:

            validate_image_generator = \
                validate_image_datagen.flow_from_directory(
                    os.path.dirname(self.validate_images_working_dir),
                    target_size=(self.input_size[0], self.input_size[1]),
                    batch_size=self.batch_size,
                    color_mode='rgb',
                    class_mode=None,
                    shuffle=False,
                    seed=seed)

            validate_mask_generator = \
                validate_mask_datagen.flow_from_directory(
                    os.path.dirname(self.validate_masks_working_dir),
                    target_size=(self.input_size[0], self.input_size[1]),
                    batch_size=self.batch_size,
                    color_mode='grayscale',
                    class_mode=None,
                    shuffle=False,
                    seed=seed)

            self.number_validation_samples = \
                len(validate_image_generator.filepaths)

            self.validate_generator = zip(validate_image_generator,
                                          validate_mask_generator)

    # pylint: disable=no-self-use
    def _create_2d_block(self, input_tensor, num_filters, kernel_size,
                         batch_norm=True):
        """
        Creates one UNet block: two Conv2D layers, each optionally followed
        by batch normalisation, with ReLU activations.
        """
        model = keras.layers.Conv2D(num_filters,
                                    kernel_size,
                                    padding='same',
                                    kernel_initializer='he_normal'
                                    )(input_tensor)
        if batch_norm:
            model = keras.layers.BatchNormalization()(model)
        model = keras.layers.Activation('relu')(model)
        model = keras.layers.Conv2D(num_filters,
                                    kernel_size,
                                    padding='same',
                                    kernel_initializer='he_normal'
                                    )(model)
        if batch_norm:
            model = keras.layers.BatchNormalization()(model)
        model = keras.layers.Activation('relu')(model)
        return model

    def _build_model(self):
        """
        Constructs the neural network, and compiles it.

        Currently, we are using a standard UNet on RGB images.
        """
        dropout = 0.1
        batch_norm = True
        kernel_size = 3
        pooling_size = 2
        num_filters = 64

        inputs = keras.Input(self.input_size)

        # Left side of UNet: each level doubles the filters and halves
        # the spatial resolution.
        conv1 = self._create_2d_block(inputs, num_filters * 1,
                                      kernel_size=kernel_size,
                                      batch_norm=batch_norm)
        pool1 = keras.layers.MaxPooling2D((pooling_size, pooling_size))(conv1)
        pool1 = keras.layers.Dropout(dropout)(pool1)

        conv2 = self._create_2d_block(pool1, num_filters * 2,
                                      kernel_size=kernel_size,
                                      batch_norm=batch_norm)
        pool2 = keras.layers.MaxPooling2D((pooling_size, pooling_size))(conv2)
        pool2 = keras.layers.Dropout(dropout)(pool2)

        conv3 = self._create_2d_block(pool2, num_filters * 4,
                                      kernel_size=kernel_size,
                                      batch_norm=batch_norm)
        pool3 = keras.layers.MaxPooling2D((pooling_size, pooling_size))(conv3)
        pool3 = keras.layers.Dropout(dropout)(pool3)

        conv4 = self._create_2d_block(pool3, num_filters * 8,
                                      kernel_size=kernel_size,
                                      batch_norm=batch_norm)
        pool4 = keras.layers.MaxPooling2D((pooling_size, pooling_size))(conv4)
        pool4 = keras.layers.Dropout(dropout)(pool4)

        # Bottom of UNet
        conv5 = self._create_2d_block(pool4, num_filters * 16,
                                      kernel_size=kernel_size,
                                      batch_norm=batch_norm)

        # Right side of UNet: upsample and concatenate with the matching
        # left-side feature map (skip connections).
        up6 = keras.layers.Conv2DTranspose(num_filters * 8, 3,
                                           strides=(2, 2),
                                           padding='same',
                                           kernel_initializer='he_normal'
                                           )(conv5)
        up6 = keras.layers.concatenate([up6, conv4])
        up6 = keras.layers.Dropout(dropout)(up6)
        conv6 = self._create_2d_block(up6, num_filters * 8, kernel_size=3,
                                      batch_norm=batch_norm)

        up7 = keras.layers.Conv2DTranspose(num_filters * 4, 3,
                                           strides=(2, 2),
                                           padding='same',
                                           kernel_initializer='he_normal'
                                           )(conv6)
        up7 = keras.layers.concatenate([up7, conv3])
        up7 = keras.layers.Dropout(dropout)(up7)
        conv7 = self._create_2d_block(up7, num_filters * 4, kernel_size=3,
                                      batch_norm=batch_norm)

        up8 = keras.layers.Conv2DTranspose(num_filters * 2, 3,
                                           strides=(2, 2),
                                           padding='same',
                                           kernel_initializer='he_normal'
                                           )(conv7)
        up8 = keras.layers.concatenate([up8, conv2])
        up8 = keras.layers.Dropout(dropout)(up8)
        conv8 = self._create_2d_block(up8, num_filters * 2, kernel_size=3,
                                      batch_norm=batch_norm)

        up9 = keras.layers.Conv2DTranspose(num_filters * 1, 3,
                                           strides=(2, 2),
                                           padding='same',
                                           kernel_initializer='he_normal'
                                           )(conv8)
        up9 = keras.layers.concatenate([up9, conv1])
        up9 = keras.layers.Dropout(dropout)(up9)
        conv9 = self._create_2d_block(up9, num_filters * 1, kernel_size=3,
                                      batch_norm=batch_norm)

        # Single-channel sigmoid output: per-pixel foreground probability.
        conv10 = keras.layers.Conv2D(1, 1,
                                     padding='same',
                                     activation='sigmoid')(conv9)

        return keras.models.Model(inputs=inputs, outputs=conv10)
[docs] def train(self): """ Method to train the neural network. Writes each epoch to tensorboard log files. :return: output of self.model.evaluate on validation set, or None. """ LOGGER.info("Training Model") optimiser = keras.optimizers.Adam(learning_rate=self.learning_rate) self.model.compile(optimizer=optimiser, loss='binary_crossentropy', metrics=['accuracy']) log_dir = os.path.join(Path(self.logs), datetime.datetime.now() .strftime("%Y%m%d-%H%M%S")) tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1) if self.omit is not None: checkpoint_filename = "checkpoint-" + self.omit + ".hdf5" monitor = 'val_accuracy' else: checkpoint_filename = "checkpoint-all.hdf5" monitor = 'accuracy' filepath = os.path.join(Path(self.logs), checkpoint_filename) checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor=monitor, verbose=1, save_best_only=True, mode='max') early_stopping = keras.callbacks.EarlyStopping(monitor=monitor, patience=self.patience, restore_best_weights=True ) validation_steps = None if self.number_validation_samples is not None: validation_steps = self.number_validation_samples // self.batch_size segmentation_history = sh.SegmentationHistory(tensor_board_dir=log_dir, data=self.validate_generator, number_of_samples=self.number_validation_samples, desired_number_images=10) else: segmentation_history = sh.SegmentationHistory(tensor_board_dir=log_dir, data=self.train_generator, number_of_samples=self.number_training_samples, desired_number_images=10) # Note: EarlyStopping must be last. See https://github.com/keras-team/keras/issues/13381 callbacks_list = [tensorboard_callback, segmentation_history, checkpoint, early_stopping] LOGGER.info("Training. Train set=%s images, batch size=%s number of batches=%s", str(self.number_training_samples), str(self.batch_size), str(self.number_training_samples // self.batch_size)) if self.number_validation_samples is not None: LOGGER.info("Training. 
Validation set=%s images, batch size=%s number of batches=%s", str(self.number_validation_samples), str(self.batch_size), str(self.number_validation_samples // self.batch_size)) self.model.fit( self.train_generator, steps_per_epoch=self.number_training_samples // self.batch_size, epochs=self.epochs, verbose=1, validation_data=self.validate_generator, # this will be None if you didn't specify self.omit validation_steps=validation_steps, # and then this won't matter if the above is None. callbacks=callbacks_list ) result = None if self.validate_generator is not None and self.number_validation_samples is not None: result = self.model.evaluate(self.validate_generator, steps=self.number_validation_samples, verbose=2 ) return result
[docs] def predict(self, rgb_image): """ Method to test a single image. Image resized to match network, segmented and then resized back to match the input size. :param rgb_image: 3 channel RGB, [0-255], uchar. :return: single channel, [0=bg|255=fg]. """ img = rgb_image * 1. / 255 resized = cv2.resize(img, (self.input_size[1], self.input_size[0])) resized = np.expand_dims(resized, axis=0) predictions = self.model.predict(resized) mask = predictions[0] # float 0-1 mask = (mask > 0.5).astype(np.ubyte) * 255 # threshold 0.5, cast to uchar, rescale [0|255] mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST) return mask
[docs] def save_model(self, filename): """ Method to save the whole trained network to disk. :param filename: file to save to. """ self.model.save(filename)
def run_rgb_unet_model(logs,
                       data,
                       working,
                       omit,
                       model,
                       save,
                       test,
                       prediction,
                       epochs,
                       batch_size,
                       learning_rate,
                       patience
                       ):
    """
    Helper function to run the RGBUnet model from
    the command line entry point.

    :param logs: directory for log files for tensorboard.
    :param data: root directory of training data.
    :param working: working directory for organising data.
    :param omit: patient identifier to omit, when doing Leave-One-Out.
    :param model: file of previously saved model.
    :param save: file to save model to.
    :param test: input image to test.
    :param prediction: output image, the result of the prediction on test image.
    :param epochs: number of epochs.
    :param batch_size: batch size.
    :param learning_rate: learning rate for optimizer.
    :param patience: number of steps to tolerate non-improving accuracy
    :raises ValueError: if the test/prediction/save arguments are
        inconsistent or invalid.
    """
    now = datetime.datetime.now()
    # Take date and time from the same instant (the original code called
    # now.today(), which produced a *second* timestamp and could disagree
    # with the time component).
    date_format = now.strftime("%Y-%m-%d")
    time_format = now.strftime("%H-%M-%S")
    logfile_name = 'rgbunet-' \
                   + date_format \
                   + '-' \
                   + time_format \
                   + '-' \
                   + str(os.getpid()) \
                   + '.log'

    # Log to both console and a per-run file.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    root_logger.addHandler(handler)

    file_handler = logging.FileHandler(logfile_name)
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)

    username = getpass.getuser()

    LOGGER.info("Starting RGBUNet version: %s", __version__)
    LOGGER.info("Starting RGBUNet with username: %s.", username)
    LOGGER.info("Starting RGBUNet with platform: %s.", str(platform.uname()))
    LOGGER.info("Starting RGBUNet with cwd: %s.", os.getcwd())
    LOGGER.info("Starting RGBUNet with path: %s.", sys.path)
    LOGGER.info("Starting RGBUNet with save: %s.", save)
    LOGGER.info("Starting RGBUNet with test: %s.", test)
    LOGGER.info("Starting RGBUNet with prediction: %s.", prediction)

    # No point loading network to test an image, if command line args wrong.
    # So, check this up front.
    if test is not None:
        if prediction is None:
            raise ValueError("If you specify a test parameter, you must also "
                             "specify the prediction parameter.")
        if test == prediction:
            raise ValueError("If you specify a test parameter, the value for "
                             "the prediction parameter must be different.")
        if os.path.isfile(prediction) or os.path.isdir(prediction):
            raise ValueError("The prediction parameter should "
                             "be a new file or directory")

    if save is not None:
        dirname = os.path.dirname(save)
        if os.path.exists(dirname) and not os.path.isdir(dirname):
            raise ValueError("Path:" + str(dirname)
                             + " exists, but is not a directory")
        # os.path.dirname always returns a string; it is empty when `save`
        # has no directory component, in which case there is nothing to make.
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname)

    rgbunet = RGBUNet(logs, data, working, omit, model,
                      learning_rate=learning_rate,
                      epochs=epochs,
                      batch_size=batch_size,
                      patience=patience
                      )

    if save is not None:
        rgbunet.save_model(save)

    if test is not None:
        if os.path.isfile(test):
            test_files = [test]
        elif os.path.isdir(test):
            test_files = ss.get_sorted_files_from_dir(test)
            if not os.path.exists(prediction):
                os.makedirs(prediction)
        else:
            raise ValueError("Invalid value for test parameter ")

        for test_file in test_files:
            img = cv2.imread(test_file)

            start_time = datetime.datetime.now()

            # OpenCV loads BGR; network was trained on RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            mask = rgbunet.predict(img)

            end_time = datetime.datetime.now()
            time_taken = (end_time - start_time).total_seconds()
            LOGGER.info("Prediction on %s took %s seconds.",
                        test_file, str(time_taken))

            if os.path.isdir(prediction):
                cv2.imwrite(
                    os.path.join(prediction, os.path.basename(test_file)),
                    mask)
            else:
                cv2.imwrite(prediction, mask)