Generate images for testing
def one_hot(mask):
    # one channel per unique value in the mask (here: background and foreground)
    return torch.cat([(mask == i).float() for i in mask.unique()], 1)
mask = torch.zeros(1, 1, 5, 10, 10)
mask[:, :, 2:4, 3:7, 3:7] = 1       # small foreground cuboid
pred_1 = one_hot(mask)              # perfect prediction
pred_2 = one_hot(mask).flip(1)      # prediction completely off (channels swapped)
pred_3 = one_hot(mask).flip(2)      # prediction partly off (flipped along the depth axis)
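For reference, the shapes involved (just a quick sanity check): one_hot adds one channel per unique value in the mask, so the single-channel mask becomes a two-channel (background/foreground) prediction.
mask.shape, pred_1.shape  # (torch.Size([1, 1, 5, 10, 10]), torch.Size([1, 2, 5, 10, 10]))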
show_images_3d(mask.squeeze(1)) # show_images_3d assumes [B x D x H x W]
show_images_3d(pred_1.squeeze())
show_images_3d(pred_2.squeeze())
show_images_3d(pred_3.squeeze())
DiceLoss()(pred_1, mask), DiceLoss()(pred_2, mask), DiceLoss()(pred_3, mask)
DiceLoss2()(pred_1, mask)#, DiceLoss2()(pred_2, mask), DiceLoss2()(pred_3, mask)
MCC loss
From Wikipedia (https://en.wikipedia.org/wiki/Matthews_correlation_coefficient):
The coefficient takes into account true and false positives and negatives and is generally regarded as a balanced measure which can be used even if the classes are of very different sizes. The MCC is in essence a correlation coefficient between the observed and predicted binary classifications; it returns a value between −1 and +1. A coefficient of +1 represents a perfect prediction, 0 no better than random prediction and −1 indicates total disagreement between prediction and observation.
Implementing the MCC score as a loss function:

$$\frac{\sum_{i}^{n} p_{i} g_{i} \cdot \sum_{i}^{n} (1-p_{i})(1-g_{i}) \;-\; \sum_{i}^{n} (1-p_{i}) g_{i} \cdot \sum_{i}^{n} p_{i} (1-g_{i}) \;+\; \gamma}{\sqrt{\left(\sum_{i}^{n} p_{i} g_{i} + \sum_{i}^{n} (1-p_{i}) g_{i}\right) \cdot \left(\sum_{i}^{n} p_{i} g_{i} + \sum_{i}^{n} p_{i} (1-g_{i})\right) \cdot \left(\sum_{i}^{n} (1-p_{i}) g_{i} + \sum_{i}^{n} (1-p_{i})(1-g_{i})\right) \cdot \left(\sum_{i}^{n} p_{i} (1-g_{i}) + \sum_{i}^{n} (1-p_{i})(1-g_{i})\right)} \;+\; \gamma}$$

where p_i is the prediction for pixel i, g_i the corresponding ground truth pixel, and gamma the smoothing factor.
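To make the connection between the formula and the code explicit, here is a minimal sketch of the MCC-based loss for a single binary channel (not the MCCLoss implementation used in this notebook; the function name and the smooth default are only for this illustration):
import torch

def mcc_loss_sketch(pred, target, smooth=1e-5):
    # flatten and compute the four sums from the formula
    p, g = pred.flatten().float(), target.flatten().float()
    tp = (p * g).sum()              # sum_i p_i * g_i
    fp = (p * (1 - g)).sum()        # sum_i p_i * (1 - g_i)
    fn = ((1 - p) * g).sum()        # sum_i (1 - p_i) * g_i
    tn = ((1 - p) * (1 - g)).sum()  # sum_i (1 - p_i) * (1 - g_i)
    numerator = tp * tn - fp * fn + smooth
    denominator = ((tp + fn) * (tp + fp) * (fn + tn) * (fp + tn)).sqrt() + smooth
    return 1 - numerator / denominator
A perfect prediction gives a loss of 0; a completely inverted prediction approaches 2.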
MCCLoss()(pred_1, mask), MCCLoss()(pred_2, mask), MCCLoss()(pred_3, mask)
class SoftMCCLoss(MCCLoss):
"""
Same as MCCLoss but can handle float values in the mask (e.g. from MixUp).
Example:
t = torch.randn(2,5); t
>>> tensor([[ 0.9113, -0.7525, -2.1771, -0.2420, -0.2245],
[ 1.9503, -1.2903, 0.1201, 0.2830, 0.0473]])
MCCLoss().make_binary(t, 1)
>>> tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
SoftMCCLoss().soft_binary(t, 0)
>>> tensor([[0.9113, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
"""
    def soft_binary(self, t, set_to_one):
        # keep the raw value where it lies within ±0.49 of the class index `set_to_one`;
        # everything else is set to 0. (or to 1. for the class-0 channel)
        return torch.where(t.gt(set_to_one - 0.49) != t.gt(set_to_one + 0.49),
                           t.float(),
                           tensor(0.).to(t.device) if set_to_one > 0 else tensor(1.).to(t.device))

    def to_one_hot(self, target: Tensor):
        target = target.squeeze(1)  # remove the solitary color channel (if there is one)
        one_hot = [self.soft_binary(target, set_to_one=i) for i in range(0, self.n_classes)]
        return torch.stack(one_hot, 1)
SoftMCCLoss()(pred_1, mask)
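To see what SoftMCCLoss adds over the hard version, feed it a target with fractional values, as MixUp would produce. This is just an illustrative blend of the test mask with a shifted copy of itself (the blend factor 0.7 is arbitrary):
soft_mask = 0.7 * mask + 0.3 * mask.roll(2, dims=-1)  # fractional values 0.0, 0.3, 0.7, 1.0
SoftMCCLoss()(pred_1, soft_mask)  # soft_binary keeps the fractional values instead of binarizing them away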
class WeightedMCCLoss(MCCLoss):
"""
Weighted version of MCCLoss
Note that class specific weight can still be added through `weights` during initialization.
Args:
alpha: weight for true positives
beta: weight for false positives
gamma: weight for false negatives
delta: weight for true negatives
"""
    def __init__(self, alpha=1., beta=1., gamma=1., delta=1., *args, **kwargs):
        store_attr()
        super().__init__(*args, **kwargs)

    def compute_loss(self, input: Tensor, target: Tensor, dims):
        tps = torch.sum(input * target, dims)               # true positives
        fps = torch.sum(input * (1 - target), dims)         # false positives
        fns = torch.sum((1 - input) * target, dims)         # false negatives
        tns = torch.sum((1 - input) * (1 - target), dims)   # true negatives
        numerator = (tps * tns - fps * fns) + self.smooth
        denominator = ((tps * self.alpha + fps * self.beta) *
                       (tps * self.alpha + fns * self.gamma) *
                       (fps * self.beta + tns * self.delta) *
                       (tns * self.delta + fns * self.gamma) + self.eps)**0.5 + self.smooth
        mcc_loss = numerator / denominator
        return 1 - mcc_loss

    def activation(self, x):
        return x
WeightedMCCLoss()(pred_1, mask)
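The weights make it possible to penalize one kind of error more than another: since fns enters the denominator scaled by gamma, a larger gamma lowers the MCC (and raises the loss) whenever false negatives are present, and beta does the same for false positives. A hypothetical setting that punishes missed foreground harder than spurious foreground (the values are arbitrary, just to illustrate the constructor arguments):
fn_heavy = WeightedMCCLoss(alpha=1., beta=0.5, gamma=2., delta=1.)
fn_heavy(pred_3, mask)  # pred_3 contains both false positives and false negatives, so the weighting matters here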