This repository has been archived by the owner on Dec 23, 2024. It is now read-only.
The dice loss function in the code:

```python
import torch
import torch.nn.functional as F

def dice_loss(logits, true, eps=1e-7):
    """Computes the Sørensen–Dice loss.

    Note that PyTorch optimizers minimize a loss. In this case, we would like to
    maximize the dice coefficient, so we return 1 minus the dice coefficient.

    Args:
        true: a tensor of shape [B, 1, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to the raw output or logits of the model.
        eps: added to the denominator for numerical stability.

    Returns:
        dice_loss: the Sørensen–Dice loss.
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        # Binary case: 2-channel one-hot target, reordered as [positive, negative]
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    dice_loss = (2.0 * intersection / (cardinality + eps)).mean()
    return 1 - dice_loss
```
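For the documented input shape [B, 1, H, W], `true.ndimension()` is 4, so `dims` evaluates to (0, 2, 3): each class gets a single Dice score pooled over the whole batch, and `.mean()` then averages over classes. In illustrative notation (introduced here, not taken from the code), with $p_{bchw}$ for `probas`, $t_{bchw}$ for `true_1_hot`, and $C$ the number of channels in `probas` (2 in the binary branch), the returned value is:

$$
\mathcal{L} = 1 - \frac{1}{C}\sum_{c=1}^{C}\frac{2\sum_{b,h,w} p_{bchw}\,t_{bchw}}{\sum_{b,h,w}\left(p_{bchw}+t_{bchw}\right)+\epsilon}
$$

Reducing over (2, 3) instead would yield one score per (sample, class) pair, and the outer mean would then run over both $b$ and $c$.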
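`intersection` and `cardinality` are both summed over `dims`. When the shape of the input `true` is (B, H, W), `dims` is (0, 2), so the sums reduce the batch and H axes but leave the W axis unreduced. However, according to the definition of the dice loss, $1-\frac{2I}{U}$ (with $I$ = `intersection` and $U$ = `cardinality`), the sums should run over the spatial axes, i.e. `dims` should be (2, 3). So I think this line should be:

```python
dims = tuple(range(2, logits.ndimension()))
```

Would you please double-check, or let me know if I have misunderstood?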
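A minimal sketch, assuming PyTorch and an integer label map `true` of shape (B, H, W), that compares the two `dims` expressions and shows a multi-class loss reduced over spatial axes only. `dice_loss_per_sample` is a hypothetical name, not part of the repository, and the binary (`num_classes == 1`) branch is omitted for brevity.

```python
import torch
import torch.nn.functional as F


def dice_loss_per_sample(logits, true, eps=1e-7):
    """Hypothetical multi-class Dice loss reduced over spatial axes only.

    logits: float tensor [B, C, H, W] with C > 1.
    true: integer labels [B, H, W] or [B, 1, H, W].
    """
    num_classes = logits.shape[1]
    if true.ndimension() == logits.ndimension():
        true = true.squeeze(1)  # drop a singleton channel axis if present
    # One-hot labels: [B, H, W] -> [B, H, W, C] -> [B, C, H, W]
    eye = torch.eye(num_classes, dtype=logits.dtype, device=logits.device)
    true_1_hot = eye[true.long()].permute(0, 3, 1, 2)
    probas = F.softmax(logits, dim=1)
    dims = tuple(range(2, logits.ndimension()))  # spatial axes, e.g. (2, 3)
    intersection = torch.sum(probas * true_1_hot, dims)  # [B, C]
    cardinality = torch.sum(probas + true_1_hot, dims)   # [B, C]
    dice = 2.0 * intersection / (cardinality + eps)
    return 1 - dice.mean()  # mean over samples and classes


B, C, H, W = 2, 3, 4, 5
torch.manual_seed(0)
logits = torch.randn(B, C, H, W)
true = torch.randint(0, C, (B, H, W))

# Reduction axes: original expression vs. the proposed one
dims_orig = (0,) + tuple(range(2, true.ndimension()))
dims_fix = tuple(range(2, logits.ndimension()))
print(dims_orig, dims_fix)  # (0, 2) (2, 3)

# With dims_orig the W axis survives the sum: result is [C, W]
print(tuple(torch.sum(logits, dims_orig).shape))  # (3, 5)
print(dice_loss_per_sample(logits, true).item())
```

One consequence of this choice: only the denominator is smoothed, so if a class does not appear in a sample's labels, that (sample, class) pair always scores 0 because `intersection` is exactly 0. Pooling over the batch, as the original `dims` do, makes such empty classes less frequent.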