15 changes: 11 additions & 4 deletions mmdet/models/losses/focal_loss.py
@@ -72,6 +72,7 @@ def py_focal_loss_with_prob(pred,
pred (torch.Tensor): The prediction probability with shape (N, C),
C is the number of classes.
target (torch.Tensor): The learning label of the prediction.
The shape of the target can be (N, C) or (N,), where (N, C) denotes the one-hot form.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
@@ -82,9 +83,10 @@ def py_focal_loss_with_prob(pred,
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
if pred.dim() != target.dim():
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]

target = target.type_as(pred)
pt = (1 - pred) * target + pred * (1 - target)
@@ -204,6 +206,8 @@ def forward(self,
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning label of the prediction.
The shape of the target can be (N, C) or (N,), where
(N, C) denotes the one-hot form.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
@@ -222,7 +226,10 @@
if self.activated:
calculate_loss_func = py_focal_loss_with_prob
else:
if torch.cuda.is_available() and pred.is_cuda:
if pred.dim() == target.dim():
# target is already in one-hot form
calculate_loss_func = py_sigmoid_focal_loss
elif torch.cuda.is_available() and pred.is_cuda:
calculate_loss_func = sigmoid_focal_loss
else:
num_classes = pred.size(1)
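With this change, FocalLoss accepts either integer class indices or a one-hot target. Below is a minimal sketch of the two call forms (shapes and values are illustrative; it assumes mmdet is installed and pred lives on CPU):

```python
import torch
import torch.nn.functional as F

from mmdet.models.losses import FocalLoss

loss_fn = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25)

pred = torch.randn(8, 4)                # logits with shape (N, C)
labels = torch.randint(0, 4, (8,))      # class indices with shape (N,)
one_hot = F.one_hot(labels, 4).float()  # one-hot targets with shape (N, C)

# Index form: converted to one-hot internally before the loss is computed.
loss_from_indices = loss_fn(pred, labels)

# One-hot form: used directly, since pred.dim() == target.dim().
loss_from_one_hot = loss_fn(pred, one_hot)

# On CPU both paths go through py_sigmoid_focal_loss, so the values match.
assert torch.allclose(loss_from_indices, loss_from_one_hot)
```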
59 changes: 56 additions & 3 deletions mmdet/models/losses/gfocal_loss.py
@@ -1,9 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.models.losses.utils import weighted_loss
from mmdet.registry import MODELS
from .utils import weighted_loss


@weighted_loss
@@ -50,6 +53,47 @@ def quality_focal_loss(pred, target, beta=2.0):
return loss


@weighted_loss
def quality_focal_loss_tensor_target(pred, target, beta=2.0, activated=False):
"""`Quality Focal Loss <https://arxiv.org/abs/2008.13367>`_
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the
number of classes
target (torch.Tensor): The learning target of the IoU-aware
classification score, with shape (N, C), C is the number of classes.
beta (float): The beta parameter for calculating the modulating factor.
Defaults to 2.0.
activated (bool): Whether the input is activated.
If True, it means the input has been activated and can be
treated as probabilities. Else, it should be treated as logits.
Defaults to False.
"""
# pred and target should be of the same size
assert pred.size() == target.size()
if activated:
pred_sigmoid = pred
loss_function = F.binary_cross_entropy
else:
pred_sigmoid = pred.sigmoid()
loss_function = F.binary_cross_entropy_with_logits

scale_factor = pred_sigmoid
target = target.type_as(pred)

zerolabel = scale_factor.new_zeros(pred.shape)
loss = loss_function(
pred, zerolabel, reduction='none') * scale_factor.pow(beta)

pos = (target != 0)
scale_factor = target[pos] - pred_sigmoid[pos]
loss[pos] = loss_function(
pred[pos], target[pos],
reduction='none') * scale_factor.abs().pow(beta)

loss = loss.sum(dim=1, keepdim=False)
return loss


@weighted_loss
def quality_focal_loss_with_prob(pred, target, beta=2.0):
r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
@@ -166,8 +210,11 @@ def forward(self,
pred (torch.Tensor): Predicted joint representation of
classification and quality (IoU) estimation with shape (N, C),
C is the number of classes.
target (tuple([torch.Tensor])): Target category label with shape
(N,) and target quality label with shape (N,).
target (Union[tuple([torch.Tensor]), torch.Tensor]): If the type
is a tuple, it should contain the target category label with
shape (N,) and the target quality label with shape (N,). If the
type is torch.Tensor, the target should be in one-hot form with
soft weights.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
@@ -184,6 +231,12 @@ def forward(self,
calculate_loss_func = quality_focal_loss_with_prob
else:
calculate_loss_func = quality_focal_loss
if isinstance(target, torch.Tensor):
# the target has shape (N, C) or (N, C, ...), which means
# it is in one-hot form with soft weights.
calculate_loss_func = partial(
quality_focal_loss_tensor_target, activated=self.activated)

loss_cls = self.loss_weight * calculate_loss_func(
pred,
target,
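Correspondingly, QualityFocalLoss now accepts the target either as the classic (label, quality) tuple or as a single soft one-hot tensor. Below is a minimal sketch mirroring the new unit test (shapes and values are illustrative; it assumes mmdet is installed):

```python
import torch
import torch.nn.functional as F

from mmdet.models.losses import QualityFocalLoss

loss_fn = QualityFocalLoss(use_sigmoid=True, beta=2.0)

pred = torch.randn(4, 5)            # joint cls-IoU logits with shape (N, C)
label = torch.tensor([0, 1, 2, 0])  # category labels with shape (N,)
score = torch.rand(4)               # IoU quality scores with shape (N,)

# Classic tuple form: (category label, quality label).
loss_tuple = loss_fn(pred, (label, score))

# New tensor form: one-hot targets weighted by the soft quality scores.
soft_target = F.one_hot(label, 5).float() * score[:, None]
loss_tensor = loss_fn(pred, soft_target)

# Both forms should yield the same loss value.
assert torch.allclose(loss_tuple, loss_tensor)
```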
61 changes: 59 additions & 2 deletions tests/test_models/test_losses/test_loss.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F
from mmengine.utils import digit_version

from mmdet.models.losses import (BalancedL1Loss, CrossEntropyLoss, DiceLoss,
@@ -29,7 +30,7 @@ def test_iou_type_loss_zeros_weight(loss_class):
@pytest.mark.parametrize('loss_class', [
BalancedL1Loss, BoundedIoULoss, CIoULoss, CrossEntropyLoss, DIoULoss,
EIoULoss, FocalLoss, DistributionFocalLoss, MSELoss, SeesawLoss,
GaussianFocalLoss, GIoULoss, IoULoss, L1Loss, QualityFocalLoss,
GaussianFocalLoss, GIoULoss, QualityFocalLoss, IoULoss, L1Loss,
VarifocalLoss, GHMR, GHMC, SmoothL1Loss, KnowledgeDistillationKLDivLoss,
DiceLoss
])
@@ -46,6 +47,26 @@ def test_loss_with_reduction_override(loss_class):
pred, target, weight, reduction_override=reduction_override)


@pytest.mark.parametrize('loss_class', [QualityFocalLoss])
@pytest.mark.parametrize('activated', [False, True])
def test_QualityFocalLoss_Loss(loss_class, activated):
input_shape = (4, 5)
pred = torch.rand(input_shape)
label = torch.Tensor([0, 1, 2, 0]).long()
quality_label = torch.rand(input_shape[0])

original_loss = loss_class(activated=activated)(pred,
(label, quality_label))
assert isinstance(original_loss, torch.Tensor)

target = torch.nn.functional.one_hot(label, 5)
target = target * quality_label.reshape(input_shape[0], 1)

new_loss = loss_class(activated=activated)(pred, target)
assert isinstance(new_loss, torch.Tensor)
assert new_loss == original_loss


@pytest.mark.parametrize('loss_class', [
IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, EIoULoss, MSELoss,
L1Loss, SmoothL1Loss, BalancedL1Loss
@@ -86,7 +107,7 @@ def test_regression_losses(loss_class, input_shape):
assert isinstance(loss, torch.Tensor)


@pytest.mark.parametrize('loss_class', [FocalLoss, CrossEntropyLoss])
@pytest.mark.parametrize('loss_class', [CrossEntropyLoss])
@pytest.mark.parametrize('input_shape', [(10, 5), (0, 5)])
def test_classification_losses(loss_class, input_shape):
if input_shape[0] == 0 and digit_version(
@@ -124,6 +145,42 @@ def test_classification_losses(loss_class, input_shape):
assert isinstance(loss, torch.Tensor)


@pytest.mark.parametrize('loss_class', [FocalLoss])
@pytest.mark.parametrize('input_shape', [(10, 5), (3, 5, 40, 40)])
def test_FocalLoss_loss(loss_class, input_shape):
pred = torch.rand(input_shape)
target = torch.randint(0, 5, (input_shape[0], ))
if len(input_shape) == 4:
B, N, W, H = input_shape
target = F.one_hot(torch.randint(0, 5, (B * W * H, )),
5).reshape(B, W, H, N).permute(0, 3, 1, 2)

# Test loss forward
loss = loss_class()(pred, target)
assert isinstance(loss, torch.Tensor)

# Test loss forward with reduction_override
loss = loss_class()(pred, target, reduction_override='mean')
assert isinstance(loss, torch.Tensor)

# Test loss forward with avg_factor
loss = loss_class()(pred, target, avg_factor=10)
assert isinstance(loss, torch.Tensor)

with pytest.raises(ValueError):
# loss can evaluate with avg_factor only if
# reduction is None, 'none' or 'mean'.
reduction_override = 'sum'
loss_class()(
pred, target, avg_factor=10, reduction_override=reduction_override)

# Test loss forward with avg_factor and reduction
for reduction_override in [None, 'none', 'mean']:
loss_class()(
pred, target, avg_factor=10, reduction_override=reduction_override)
assert isinstance(loss, torch.Tensor)


@pytest.mark.parametrize('loss_class', [GHMR])
@pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
def test_GHMR_loss(loss_class, input_shape):
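As a quick sanity check, the new tests can be run in isolation (assuming pytest and mmdet's test dependencies are available): `pytest tests/test_models/test_losses/test_loss.py -k "FocalLoss"` selects test_FocalLoss_loss, test_QualityFocalLoss_Loss, and the other FocalLoss-related parametrized cases.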