# Source code for sksurgerytorch.models.volume_to_surface

"""V2SNet Model Impementation"""

import logging
import numpy as np
import torch

from sksurgerytorch.models.volume_to_surface_model import Model

# Module-level logger; Volume2SurfaceCNN.__init__ uses it to report whether
# the GPU or the CPU was selected.
LOGGER = logging.getLogger(__name__)

# Disable these pylint checks for the remainder of this module.
#pylint:disable=unused-variable, super-with-arguments, invalid-name

class Volume2SurfaceCNN:
    """Class to encapsulate the network from 'Non-Rigid Volume to Surface
    Registration using a Data-Driven Biomechanical Model'. Thanks to
    `Micha Pfeiffer <https://gitlab.com/nct_tso_public/Volume2SurfaceCNN>`_
    for their network implementation.

    :param mask: If true, use masking
    :type mask: bool
    :param weights: Path to trained model weights (.tar file)
    :type weights: str
    :param grid_size: Number of voxels along each edge of the cubic input
        grid; inputs to :meth:`predict` must contain ``grid_size ** 3``
        values.
    :type grid_size: int
    """

    def __init__(self,
                 mask: bool = True,
                 weights: str = None,
                 grid_size: int = 64):
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            LOGGER.info("Using GPU")
        else:
            self.device = torch.device("cpu")
            LOGGER.info("Using CPU")

        self.mask = mask
        self.grid_size = grid_size

        self.model = Model(mask)

        if weights is not None:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code - only load checkpoints from trusted sources.
            checkpoint = torch.load(weights, map_location=self.device)
            self.model.load_state_dict(checkpoint["model"])
            # Any "optimizer" entry in the checkpoint is deliberately
            # ignored: this class only runs inference, so optimizer state is
            # never used.

        self.model.to(self.device)
        self.model.eval()

    def _to_network_input(self, volume: np.ndarray) -> torch.Tensor:
        """Convert a voxel grid to a (1, 1, gs, gs, gs) float32 tensor.

        :param volume: Array with ``grid_size ** 3`` elements (flat or
            already shaped as a cube), read in C order.
        :return: Copy of the data as a float32 tensor on ``self.device``,
            with leading batch and channel axes of size 1.
        """
        gs = self.grid_size
        # Same element order as reshaping to (gs, gs, gs, 1), moving the
        # singleton channel axis to the front and adding a batch axis.
        batched = np.reshape(volume, (1, 1, gs, gs, gs))
        # torch.tensor always copies, so the caller's array is never aliased.
        return torch.tensor(batched, dtype=torch.float32, device=self.device)

    def predict(self,
                preoperative: np.ndarray,
                intraoperative: np.ndarray) -> np.ndarray:
        """Predict the displacement field between model and surface.

        :param preoperative: Preoperative volume sampled on the
            ``grid_size ** 3`` grid. It is treated as a signed distance
            field: values below zero are considered inside the mesh.
        :type preoperative: np.ndarray
        :param intraoperative: Intraoperative surface sampled on the same
            ``grid_size ** 3`` grid.
        :type intraoperative: np.ndarray
        :return: Displacement field with one row per voxel, shape
            ``(grid_size ** 3, C)`` where ``C`` is the number of channels of
            the network's full-resolution output.
        :rtype: np.ndarray
        :raises IOError: If ``preoperative`` contains no negative values.
        """
        preoperative = self._to_network_input(preoperative)
        intraoperative = self._to_network_input(intraoperative)

        # If no values in the SDF are lower than 0 then this is not a valid
        # mesh.
        if not (preoperative < 0).any():
            raise IOError(
                "Sample contains no internal points "
                "(no valid signed distance function?)")

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            out64, _, _, _ = self.model(preoperative, intraoperative)

        # This is the same sequence of commands as in Model/data.py in
        # original v2snet, saveSample() function.
        # squeeze(0) removes only the batch axis, so a grid size of 1 (or a
        # single output channel) keeps the layout the transpose expects.
        np_displacement = out64.squeeze(0).detach().cpu().numpy()
        np_displacement = np.transpose(np_displacement, (1, 2, 3, 0))
        return np_displacement.reshape(self.grid_size ** 3, -1)