@ -19,7 +19,8 @@ class VarifocalLoss(nn.Module):
""" Initialize the VarifocalLoss class. """
super ( ) . __init__ ( )
def forward ( self , pred_score , gt_score , label , alpha = 0.75 , gamma = 2.0 ) :
@staticmethod
def forward ( pred_score , gt_score , label , alpha = 0.75 , gamma = 2.0 ) :
""" Computes varfocal loss. """
weight = alpha * pred_score . sigmoid ( ) . pow ( gamma ) * ( 1 - label ) + gt_score * label
with torch . cuda . amp . autocast ( enabled = False ) :
@ -28,14 +29,14 @@ class VarifocalLoss(nn.Module):
return loss
# Losses
class FocalLoss ( nn . Module ) :
""" Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5). """
def __init__ ( self , ) :
super ( ) . __init__ ( )
def forward ( self , pred , label , gamma = 1.5 , alpha = 0.25 ) :
@staticmethod
def forward ( pred , label , gamma = 1.5 , alpha = 0.25 ) :
""" Calculates and updates confusion matrix for object detection/classification tasks. """
loss = F . binary_cross_entropy_with_logits ( pred , label , reduction = ' none ' )
# p_t = torch.exp(-loss)
@ -89,6 +90,7 @@ class BboxLoss(nn.Module):
class KeypointLoss ( nn . Module ) :
""" Criterion class for computing training losses. """
def __init__ ( self , sigmas ) - > None :
super ( ) . __init__ ( )
@ -103,8 +105,8 @@ class KeypointLoss(nn.Module):
return kpt_loss_factor * ( ( 1 - torch . exp ( - e ) ) * kpt_mask ) . mean ( )
# Criterion class for computing Detection training losses
class v8DetectionLoss :
""" Criterion class for computing training losses. """
def __init__ ( self , model ) : # model must be de-paralleled
@ -199,8 +201,8 @@ class v8DetectionLoss:
return loss . sum ( ) * batch_size , loss . detach ( ) # loss(box, cls, dfl)
# Criterion class for computing training losses
class v8SegmentationLoss ( v8DetectionLoss ) :
""" Criterion class for computing training losses. """
def __init__ ( self , model ) : # model must be de-paralleled
super ( ) . __init__ ( model )
@ -294,8 +296,8 @@ class v8SegmentationLoss(v8DetectionLoss):
return ( crop_mask ( loss , xyxy ) . mean ( dim = ( 1 , 2 ) ) / area ) . mean ( )
# Criterion class for computing training losses
class v8PoseLoss ( v8DetectionLoss ) :
""" Criterion class for computing training losses. """
def __init__ ( self , model ) : # model must be de-paralleled
super ( ) . __init__ ( model )
@ -374,7 +376,8 @@ class v8PoseLoss(v8DetectionLoss):
return loss . sum ( ) * batch_size , loss . detach ( ) # loss(box, cls, dfl)
def kpts_decode ( self , anchor_points , pred_kpts ) :
@staticmethod
def kpts_decode ( anchor_points , pred_kpts ) :
""" Decodes predicted keypoints to image coordinates. """
y = pred_kpts . clone ( )
y [ . . . , : 2 ] * = 2.0
@ -384,6 +387,7 @@ class v8PoseLoss(v8DetectionLoss):
class v8ClassificationLoss :
""" Criterion class for computing training losses. """
def __call__ ( self , preds , batch ) :
""" Compute the classification loss between predictions and true labels. """