|
|
|
@ -33,9 +33,8 @@ class ClassificationTrainer(BaseTrainer): |
|
|
|
|
if weights: |
|
|
|
|
model.load(weights) |
|
|
|
|
|
|
|
|
|
pretrained = self.args.pretrained |
|
|
|
|
for m in model.modules(): |
|
|
|
|
if not pretrained and hasattr(m, 'reset_parameters'): |
|
|
|
|
if not self.args.pretrained and hasattr(m, 'reset_parameters'): |
|
|
|
|
m.reset_parameters() |
|
|
|
|
if isinstance(m, torch.nn.Dropout) and self.args.dropout: |
|
|
|
|
m.p = self.args.dropout # set dropout |
|
|
|
@ -61,8 +60,7 @@ class ClassificationTrainer(BaseTrainer): |
|
|
|
|
elif model.endswith('.yaml'): |
|
|
|
|
self.model = self.get_model(cfg=model) |
|
|
|
|
elif model in torchvision.models.__dict__: |
|
|
|
|
pretrained = True |
|
|
|
|
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) |
|
|
|
|
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None) |
|
|
|
|
else: |
|
|
|
|
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') |
|
|
|
|
ClassificationModel.reshape_outputs(self.model, self.data['nc']) |
|
|
|
|