# -*- coding: utf-8 -*-
"""
Module to implement Hierarchical Deep Stereo Matching on High Resolution Images
network.
"""
# pylint:disable=invalid-name, line-too-long, no-else-return
# pylint:disable=useless-object-inheritance, consider-using-from-import
# pylint:disable=consider-using-f-string
import logging
import time
from typing import Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable

from sksurgerytorch.models.high_res_stereo_model import disparityregression
from sksurgerytorch.models.high_res_stereo_model import HSMNet_model
LOGGER = logging.getLogger(__name__)
class HSMNet:
    """Class to encapsulate network from 'Hierarchical Deep Stereo Matching \
on High Resolution Images'.

    Thanks to \
    `Gengshang Yang <https://github.com/gengshan-y/high-res-stereo>`_, for \
    their network implementation.

    :param max_disp: Maximum number of disparity levels
    :param entropy_threshold: Pixels with entropy above this value will be \
    ignored in the disparity map. Disabled if set to -1.
    :param level: Set to 1, 2 or 3 to trade off quality of depth estimation \
    against runtime. 1 = best depth estimation, longer runtime, \
    3 = worst depth estimation, fastest runtime.
    :param scale_factor: Images can be resized before passing to the network, \
    for performance improvements. This sets the scale factor.
    :param weights: Path to trained model weights (.tar file)
    """

    def __init__(self,
                 max_disp: int = 255,
                 entropy_threshold: float = -1,
                 level: int = 1,
                 scale_factor: float = 0.5,
                 weights=None,
                 ):

        # Prefer the first CUDA device when one is available.
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            LOGGER.info("Using GPU")
        else:
            self.device = torch.device("cpu")
            LOGGER.info("Using CPU")

        LOGGER.info("Device: %s", self.device)

        self.max_disp = max_disp
        self.scale_factor = scale_factor
        self.entropy_threshold = entropy_threshold
        self.level = level

        self.model = HSMNet_model(max_disp, entropy_threshold, self.device,
                                  level)
        self.model = nn.DataParallel(self.model)
        self.model.to(self.device)
        self.model.eval()

        # Results of the most recent predict() call.
        self.pred_disp = None
        self.entropy = None

        if weights:
            LOGGER.info("Loading weights from %s", weights)
            pretrained_dict = torch.load(weights, map_location=self.device)
            # Drop the saved disparity-regression entries; predict() rebuilds
            # them for the requested maximum disparity.
            pretrained_dict['state_dict'] = {
                k: v for k, v in pretrained_dict['state_dict'].items()
                if 'disp' not in k}
            self.model.load_state_dict(
                pretrained_dict['state_dict'], strict=False)
            LOGGER.info("Loaded weights")

        LOGGER.info(
            'Number of model parameters: %d',
            sum(p.data.nelement() for p in self.model.parameters()))

    def predict(self,
                left_image: np.ndarray,
                right_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Predict disparity from a pair of stereo images.

        :param left_image: Left stereo image, 3 channel RGB
        :type left_image: np.ndarray
        :param right_image: Right stereo image, 3 channel RGB
        :type right_image: np.ndarray
        :return: Tuple of (predicted disparity, resized to the input image
            resolution; entropy map, at the network's working resolution).
        :rtype: Tuple[np.ndarray, np.ndarray]
        """
        imagenet_stats = {'mean': [0.485, 0.456, 0.406],
                          'std': [0.229, 0.224, 0.225]}
        processed = transforms.Compose([
            toTensorLegacy(),
            transforms.Normalize(**imagenet_stats),
        ])

        # Maximum disparity at the scaled resolution, rounded down to a
        # multiple of 64; the network needs at least 128.
        disp_scaled = int(self.max_disp * self.scale_factor // 64 * 64)
        self.model.module.maxdisp = disp_scaled
        if self.model.module.maxdisp == 64:
            self.model.module.maxdisp = 128

        # Rebuild the disparity regressors for the chosen maximum disparity.
        self.model.module.disp_reg8 = disparityregression(
            self.model.module.maxdisp, 16).to(self.device)
        self.model.module.disp_reg16 = disparityregression(
            self.model.module.maxdisp, 16).to(self.device)
        self.model.module.disp_reg32 = disparityregression(
            self.model.module.maxdisp, 32).to(self.device)
        self.model.module.disp_reg64 = disparityregression(
            self.model.module.maxdisp, 64).to(self.device)
        LOGGER.info("Model.module.maxdisp %s", self.model.module.maxdisp)

        orig_img_size = left_image.shape[:2]

        # Resize inputs by the requested scale factor.
        imgL_o = cv2.resize(
            left_image,
            None,
            fx=self.scale_factor,
            fy=self.scale_factor,
            interpolation=cv2.INTER_CUBIC)
        imgR_o = cv2.resize(
            right_image,
            None,
            fx=self.scale_factor,
            fy=self.scale_factor,
            interpolation=cv2.INTER_CUBIC)

        imgL = processed(imgL_o).numpy()
        imgR = processed(imgR_o).numpy()
        imgL = np.reshape(imgL, [1, 3, imgL.shape[1], imgL.shape[2]])
        imgR = np.reshape(imgR, [1, 3, imgR.shape[1], imgR.shape[2]])

        # Pad (top and left) so height and width are multiples of 64.
        max_h = int(imgL.shape[2] // 64 * 64)
        max_w = int(imgL.shape[3] // 64 * 64)
        if max_h < imgL.shape[2]:
            max_h += 64
        if max_w < imgL.shape[3]:
            max_w += 64

        top_pad = max_h - imgL.shape[2]
        left_pad = max_w - imgL.shape[3]
        imgL = np.pad(imgL, ((0, 0), (0, 0), (top_pad, 0), (0, left_pad)),
                      mode='constant', constant_values=0)
        imgR = np.pad(imgR, ((0, 0), (0, 0), (top_pad, 0), (0, left_pad)),
                      mode='constant', constant_values=0)

        # Variable is a deprecated no-op wrapper since torch 0.4;
        # plain tensors suffice.
        imgL = torch.FloatTensor(imgL).to(self.device)
        imgR = torch.FloatTensor(imgR).to(self.device)

        LOGGER.info("Predicting disparity")
        with torch.no_grad():
            # Synchronize around the forward pass so the timing is accurate
            # on GPU (CUDA kernels launch asynchronously).
            if self.device.type.startswith("cuda"):
                torch.cuda.synchronize()
            start_time = time.time()
            self.pred_disp, self.entropy = self.model(imgL, imgR)
            if self.device.type.startswith("cuda"):
                torch.cuda.synchronize()
            ttime = time.time() - start_time
            LOGGER.info("Inference time: %.2f ms", ttime * 1000)

        self.pred_disp = torch.squeeze(self.pred_disp).data.cpu().numpy()
        LOGGER.info("Predicted disparity")

        # Strip the padding added above (recomputed relative to the scaled,
        # unpadded input image). The entropy map is cropped first, while
        # pred_disp still has its padded width.
        top_pad = max_h - imgL_o.shape[0]
        left_pad = max_w - imgL_o.shape[1]
        self.entropy = self.entropy[
            top_pad:, :self.pred_disp.shape[1] - left_pad].cpu().numpy()
        self.pred_disp = self.pred_disp[
            top_pad:, :self.pred_disp.shape[1] - left_pad]

        # Resize disparity back to the original resolution; disparity values
        # are divided by the scale factor so they match the full-size images.
        self.pred_disp = cv2.resize(
            self.pred_disp / self.scale_factor,
            (orig_img_size[1], orig_img_size[0]),
            interpolation=cv2.INTER_LINEAR)

        # Mark inf/NaN pixels as invalid (NaN != NaN detects NaNs).
        # pylint:disable=assignment-from-no-return
        invalid = np.logical_or(
            self.pred_disp == np.inf,
            self.pred_disp != self.pred_disp)
        self.pred_disp[invalid] = np.inf

        torch.cuda.empty_cache()
        return self.pred_disp, self.entropy
class toTensorLegacy(object):
    """Transform converting an image to a float tensor, reproducing the
    behaviour of TorchVision 0.2.0's ``transforms.ToTensor``.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL or numpy.ndarray): Image to be converted to tensor
        Returns:
            Tensor: Converted image, CHW float in [0, 1].
        """
        # pylint:disable=no-member
        if isinstance(pic, np.ndarray):
            # This is what TorchVision 0.2.0 returns for transforms.toTensor()
            # for np.ndarray: HWC -> CHW, scaled from [0, 255] to [0, 1].
            return torch.from_numpy(pic.transpose((2, 0, 1))).float().div(255)
        else:
            # torchvision.transforms has no 'to_tensor' attribute (the
            # function lives in transforms.functional), so the previous
            # transforms.to_tensor(pic) raised AttributeError for PIL input.
            return transforms.ToTensor()(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'
def run_hsmnet_model(max_disp,
                     entropy_threshold,
                     level,
                     scale_factor,
                     weights,
                     left_image,
                     right_image,
                     output_file
                     ):
    """This is for the command line entry point.

    Builds an :class:`HSMNet`, runs it on a stereo pair read from disk, and
    writes the predicted disparity to ``output_file``.

    :param max_disp: Maximum number of disparity levels
    :param entropy_threshold: Entropy threshold (-1 to disable)
    :param level: Quality/runtime trade-off (1 = best, 3 = fastest)
    :param scale_factor: Pre-network image resize factor
    :param weights: Path to trained model weights (.tar file)
    :param left_image: Path to the left image file
    :param right_image: Path to the right image file
    :param output_file: Path for the output disparity image
    :raises ValueError: if either input image cannot be read
    """
    network = HSMNet(max_disp=max_disp,
                     entropy_threshold=entropy_threshold,
                     level=level,
                     scale_factor=scale_factor,
                     weights=weights)

    left = cv2.imread(left_image)
    right = cv2.imread(right_image)

    # cv2.imread returns None (rather than raising) for a missing or
    # unreadable file; fail early with a clear message instead of letting
    # the network crash on a None input.
    if left is None:
        raise ValueError("Could not read left image: " + str(left_image))
    if right is None:
        raise ValueError("Could not read right image: " + str(right_image))

    disp, _entropy = network.predict(left, right)
    cv2.imwrite(output_file, disp)