Generate images for testing

import torch

def one_hot(mask):
    return torch.cat([(mask == i).float() for i in mask.unique()], 1)

mask = torch.zeros(1, 1, 5, 10, 10)   # [B, C, D, H, W]
mask[:, :, 2:4, 3:7, 3:7] = 1         # cuboid foreground region

pred_1 = one_hot(mask) 
pred_2 = one_hot(mask).flip(1) # prediction completely off
pred_3 = one_hot(mask).flip(2) # prediction partly off

show_images_3d(mask) # show_images_3d assumes [B x D x H x W]
show_images_3d(pred_1.squeeze())
show_images_3d(pred_2.squeeze())
show_images_3d(pred_3.squeeze())
../faimed3d/basics.py:256: UserWarning: Object is not a rank 3 tensor but a rank 4 tensor. Removing the first dimension
  warn('Object is not a rank 3 tensor but a rank {} tensor. Removing the first dimension'.format(t.ndim))

Loss functions

DICE Loss

class BaseLoss[source]

BaseLoss(weights=None, avg='macro')

Base class for loss functions

Args:
    targ:    A tensor of shape [B, C, D, H, W].
    pred:    A tensor of shape [B, C, D, H, W]. Corresponds to
             the raw output or logits of the model.
    weights: A list or tuple giving weights for each class, or None.
    avg:     'macro' computes the loss for each B x C and averages the losses.
             'micro' computes the loss for each B and averages the losses.
Returns:
    loss:    computed loss (scalar)
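
To make the difference between the two averaging modes concrete, here is a minimal sketch using a plain Dice score (the reduction dimensions below are an assumption for illustration, not the exact implementation of `BaseLoss`):

import torch

def dice_per_item(pred, targ, dims):
    # plain Dice score, reduced over the given dimensions
    inter = (pred * targ).sum(dims)
    union = pred.sum(dims) + targ.sum(dims)
    return 2 * inter / union.clamp(min=1e-7)

pred = torch.rand(2, 3, 5, 10, 10)                            # [B, C, D, H, W]
targ = (torch.rand(2, 3, 5, 10, 10) > 0.5).float()

macro = dice_per_item(pred, targ, dims=(2, 3, 4)).mean()      # one score per B x C, then averaged
micro = dice_per_item(pred, targ, dims=(1, 2, 3, 4)).mean()   # one score per B, then averaged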

class DiceLoss[source]

DiceLoss(method='miletari', alpha=0.5, beta=0.5, eps=1e-07, smooth=0.0, *args, **kwargs) :: BaseLoss

Simple DICE loss as described in:
    https://arxiv.org/pdf/1911.02855.pdf

Computes the Sørensen–Dice score, for which larger is better.
Since PyTorch optimizers minimize a loss, the score is subtracted from 1.

Args:
    inherited from `BaseLoss`
    targ:    A tensor of shape [B, C, D, H, W].
    pred:    A tensor of shape [B, C, D, H, W]. Corresponds to
             the raw output or logits of the model.
    weights: A list or tuple giving weights for each class, or None.
    avg:     'macro' computes the loss for each B x C and averages the losses.
             'micro' computes the loss for each B and averages the losses.

    Unique for `DiceLoss`
    method:  The method by which the DICE score is calculated.
                "simple"   = standard DICE loss
                "miletari" = squared denominator for faster convergence
                "tversky"  = variant of the DICE loss which allows weighting FP vs. FN
    alpha, beta: weights for FP and FN in the "tversky" loss; if both values are 0.5,
             the "tversky" loss corresponds to the "simple" DICE loss
    smooth:  Added smoothing factor.
    eps:     Added to the denominator for numerical stability (avoid division by 0).
Returns:
    dice_loss: the Sørensen–Dice loss
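
The three methods only differ in how the denominator is built. A rough sketch, assuming probabilities `p` and one-hot targets `g` (this illustrates the formulas, it is not the exact implementation of `DiceLoss`):

import torch

def dice_variants(p, g, method='simple', alpha=0.5, beta=0.5, eps=1e-7):
    dims = (2, 3, 4)                        # reduce over D, H, W
    tp = (p * g).sum(dims)
    if method == 'simple':                  # standard DICE
        score = 2 * tp / (p.sum(dims) + g.sum(dims) + eps)
    elif method == 'miletari':              # squared denominator for faster convergence
        score = 2 * tp / ((p ** 2).sum(dims) + (g ** 2).sum(dims) + eps)
    elif method == 'tversky':               # alpha = beta = 0.5 reduces to 'simple'
        fp = (p * (1 - g)).sum(dims)
        fn = ((1 - p) * g).sum(dims)
        score = tp / (tp + alpha * fp + beta * fn + eps)
    return 1 - score.mean()                 # optimizers minimize, so subtract the score from 1

With, for example, `alpha=0.3, beta=0.7` the Tversky variant penalizes false negatives more heavily than false positives.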

class DiceLoss2[source]

DiceLoss2(epsilon=1e-05) :: Module

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in
a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:`to`, etc.

:ivar training: Boolean represents whether this module is in training or
                evaluation mode.
:vartype training: bool
DiceLoss()(pred_1, mask), DiceLoss()(pred_2, one_hot(mask)), DiceLoss()(pred_3, one_hot(mask))
(tensor(0.), tensor(1.), tensor(0.2671))
DiceLoss2()(pred_1, mask)#, DiceLoss2()(pred_2, mask), DiceLoss2()(pred_3, mask)
tensor(0.)

MCC Loss

From Wikipedia (https://en.wikipedia.org/wiki/Matthews_correlation_coefficient):

The coefficient takes into account true and false positives and negatives and is generally regarded as a balanced measure which can be used even if the classes are of very different sizes. The MCC is in essence a correlation coefficient between the observed and predicted binary classifications; it returns a value between −1 and +1. A coefficient of +1 represents a perfect prediction, 0 no better than random prediction, and −1 indicates total disagreement between prediction and observation.

Implementing the MCC score as a loss function:
$$\mathrm{MCC} = \frac{ \sum_{i}^{n} p_{i} g_{i} \cdot \sum_{i}^{n} (1-p_{i})(1-g_{i}) - \sum_{i}^{n} (1-p_{i}) g_{i} \cdot \sum_{i}^{n} p_{i} (1-g_{i}) }{ \sqrt{ \left(\sum_{i}^{n} p_{i} g_{i} + \sum_{i}^{n} (1-p_{i}) g_{i}\right) \left(\sum_{i}^{n} p_{i} g_{i} + \sum_{i}^{n} p_{i} (1-g_{i})\right) \left(\sum_{i}^{n} (1-p_{i}) g_{i} + \sum_{i}^{n} (1-p_{i})(1-g_{i})\right) \left(\sum_{i}^{n} p_{i} (1-g_{i}) + \sum_{i}^{n} (1-p_{i})(1-g_{i})\right) } }$$

where p_i is the prediction for pixel i and g_i is the corresponding ground truth pixel; a smoothing factor can additionally be added to the numerator and denominator.
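
A minimal sketch of this formula, assuming soft predictions `p` and one-hot targets `g` (the actual `MCCLoss` additionally handles activation, class weights and macro/micro averaging):

import torch

def mcc_loss_sketch(p, g, smooth=0.0, eps=1e-7):
    dims = (2, 3, 4)                        # reduce over D, H, W
    tp = (p * g).sum(dims)                  # true positives
    fp = (p * (1 - g)).sum(dims)            # false positives
    fn = ((1 - p) * g).sum(dims)            # false negatives
    tn = ((1 - p) * (1 - g)).sum(dims)      # true negatives
    numerator = tp * tn - fp * fn + smooth
    denominator = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + eps).sqrt() + smooth
    return 1 - (numerator / denominator).mean()   # subtract from 1, so 0 is a perfect prediction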

class MCCLoss[source]

MCCLoss(eps=1e-07, smooth=0.0, *args, **kwargs) :: BaseLoss

Computes the MCC loss.

For this loss to work best, the input should be in range 0-1, e.g. enforced through a sigmoid or softmax.
Note that PyTorch optimizers minimize a loss, so the loss is subtracted from 1.
While the MCC score can become negative, the MCC loss should not go below 0.


Args:
    inherited from `BaseLoss`
    targ:    A tensor of shape [B, C, D, H, W].
    pred:    A tensor of shape [B, C, D, H, W]. Corresponds to
             the raw output or logits of the model.
    weights: A list or tuple giving weights for each class, or None.
    avg:     'macro' computes the loss for each B x C and averages the losses.
             'micro' computes the loss for each B and averages the losses.

    Unique for `MCCLoss`
    smooth:  Smoothing factor, default is 0.
    eps:     Added for numerical stability.

Returns:
    mcc_loss: loss based on the Matthews correlation coefficient
MCCLoss()(pred_1, mask), MCCLoss()(pred_2, mask), MCCLoss()(pred_3, mask)
(tensor(0.), tensor(2.), tensor(0.5342))
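
The completely wrong prediction (`pred_2`) yields an MCC of −1, so its loss becomes 1 − (−1) = 2; the MCC loss therefore ranges from 0 (perfect prediction) to 2 (total disagreement).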
class SoftMCCLoss(MCCLoss):
    """
    Same as MCCLoss but can handle float values in the mask (e.g. from MixUp). 
    Example: 
        t = torch.randn(2,5); t
        >>> tensor([[ 0.9113, -0.7525, -2.1771, -0.2420, -0.2245],
                    [ 1.9503, -1.2903,  0.1201,  0.2830,  0.0473]])
                   
        MCCLoss().make_binary(t, 1)
        >>> tensor([[0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0.]])
        
        SoftMCCLoss().soft_binary(t, 0)
        >>> tensor([[0.9113, 0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
    """
    
    def soft_binary(self, t, set_to_one):
        # keep values within ±0.49 of `set_to_one` unchanged; everything else is set
        # to 0. (or to 1. for the background class, where set_to_one == 0)
        return torch.where(t.gt(set_to_one - 0.49) != t.gt(set_to_one + 0.49),
                           t.float(),
                           tensor(0.).to(t.device) if set_to_one > 0 else tensor(1.).to(t.device))
    
    def to_one_hot(self, target:Tensor):
        target = target.squeeze(1) # remove the solitary color channel (if there is one)
        one_hot = [self.soft_binary(target, set_to_one=i) for i in range(0, self.n_classes)]
        return torch.stack(one_hot, 1)
SoftMCCLoss()(pred_1, mask)
tensor(0.)
class WeightedMCCLoss(MCCLoss):
    """
    Weighted version of MCCLoss 
    Note that class specific weight can still be added through `weights` during initialization. 
    
    Args: 
        alpha: weight for true positives
        beta: weight for false positives
        gamma: weight for false negatives
        delta: weight for true negatives
    """

    def __init__(self, alpha=1., beta=1., gamma=1., delta=1., *args, **kwargs):
        store_attr()
        super().__init__(*args, **kwargs)

    def compute_loss(self, input: Tensor, target: Tensor, dims):
        tps = torch.sum(input * target, dims)              # true positives
        fps = torch.sum(input * (1 - target), dims)        # false positives
        fns = torch.sum((1 - input) * target, dims)        # false negatives
        tns = torch.sum((1 - input) * (1 - target), dims)  # true negatives

        numerator = (tps * tns - fps * fns) + self.smooth
        denominator = ((tps * self.alpha + fps * self.beta)
                       * (tps * self.alpha + fns * self.gamma)
                       * (fps * self.beta + tns * self.delta)
                       * (tns * self.delta + fns * self.gamma)
                       + self.eps)**0.5 + self.smooth

        mcc_loss = numerator / denominator

        return 1 - mcc_loss

    def activation(self, x):
        # identity: predictions are used as-is, so apply sigmoid/softmax beforehand if needed
        return x
WeightedMCCLoss()(pred_1, mask)
tensor(0.)
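
The four weights make it possible to shift the balance between error types; for example, increasing `gamma` penalizes false negatives more strongly (the values below are arbitrary):

# gamma > 1 scales the false-negative terms in the denominator,
# pushing the loss to punish missed foreground voxels harder
WeightedMCCLoss(gamma=2., beta=0.5)(pred_3, mask)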