You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
67 lines
2.4 KiB
67 lines
2.4 KiB
import copy |
|
|
|
import paddle |
|
import paddle.nn as nn |
|
from ppcls.utils import logger |
|
|
|
from .celoss import CELoss, MixCELoss |
|
from .googlenetloss import GoogLeNetLoss |
|
from .centerloss import CenterLoss |
|
from .emlloss import EmlLoss |
|
from .msmloss import MSMLoss |
|
from .npairsloss import NpairsLoss |
|
from .trihardloss import TriHardLoss |
|
from .triplet import TripletLoss, TripletLossV2 |
|
from .supconloss import SupConLoss |
|
from .pairwisecosface import PairwiseCosface |
|
from .dmlloss import DMLLoss |
|
from .distanceloss import DistanceLoss |
|
|
|
from .distillationloss import DistillationCELoss |
|
from .distillationloss import DistillationGTCELoss |
|
from .distillationloss import DistillationDMLLoss |
|
from .distillationloss import DistillationDistanceLoss |
|
from .distillationloss import DistillationRKDLoss |
|
from .multilabelloss import MultiLabelLoss |
|
|
|
from .deephashloss import DSHSDLoss, LCDSHLoss |
|
|
|
|
|
class CombinedLoss(nn.Layer):
    """Weighted sum of several loss functions built from a config list.

    Each entry of ``config_list`` is a single-key dict ``{LossName: params}``.
    ``params`` must contain a ``weight`` key. The remaining keys are passed as
    keyword arguments to the loss class named ``LossName``, which is resolved
    by name in this module (e.g. one of the loss classes imported above).

    Calling the instance returns a dict holding every weighted term produced
    by the individual loss functions, plus a ``"loss"`` entry with their sum.
    """

    def __init__(self, config_list):
        super().__init__()
        # NOTE(review): plain Python lists -- presumably the sub-losses are
        # therefore not registered as sublayers of this Layer; confirm before
        # relying on their parameters appearing in state_dict()/parameters().
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(config_list, list), (
            'operator config should be a list')
        for config in config_list:
            assert isinstance(config, dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            # A bare "- LossName:" entry in YAML yields None; treat it as "no
            # params" so the weight assertion below reports the real problem.
            if param is None:
                param = {}
            assert isinstance(param, dict), "yaml format error"
            # Shallow copy: popping "weight" must not mutate the caller's config.
            param = dict(param)
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            self.loss_weight.append(param.pop("weight"))
            # SECURITY NOTE: eval() resolves `name` against this module's
            # namespace but will execute any expression it is given, so the
            # loss config must come from a trusted source.
            self.loss_func.append(eval(name)(**param))

    def __call__(self, input, batch):
        """Evaluate every configured loss and combine the weighted results.

        Args:
            input: model output, forwarded unchanged to each loss function.
            batch: the current batch, forwarded unchanged to each loss function.

        Returns:
            dict: one entry per term returned by the loss functions, each
            multiplied by its loss function's weight, plus ``"loss"`` holding
            the sum of all terms. If several loss functions return the same
            key, their weighted values are summed under that key, so every
            term counts toward ``"loss"``.
        """
        loss_dict = {}
        for loss_func, weight in zip(self.loss_func, self.loss_weight):
            for key, value in loss_func(input, batch).items():
                # x * 1 is exact, so skipping the multiply changes nothing
                # numerically. It keeps the common single-loss,
                # weight-1 classification step as cheap as before.
                if weight != 1:
                    value = value * weight
                # Accumulate rather than overwrite, so a repeated key from a
                # second loss function is not silently dropped from the total.
                if key in loss_dict:
                    loss_dict[key] = loss_dict[key] + value
                else:
                    loss_dict[key] = value
        terms = list(loss_dict.values())
        # A single term is its own sum. add_n is only needed for two or more.
        loss_dict["loss"] = terms[0] if len(terms) == 1 else paddle.add_n(terms)
        return loss_dict
|
|
|
|
|
def build_loss(config):
    """Build a CombinedLoss from a list of single-key loss config dicts.

    The config is deep-copied before it is handed over, so nothing
    CombinedLoss does with the entries (such as popping each entry's
    "weight") can leak back into the caller's configuration.
    """
    isolated_config = copy.deepcopy(config)
    return CombinedLoss(isolated_config)
|
|
|