---
comments: true
description: Discover how to customize and extend base Ultralytics YOLO Trainer engines. Support your custom model and dataloader by overriding built-in functions.
keywords: Ultralytics, YOLO, trainer engines, BaseTrainer, DetectionTrainer, customizing trainers, extending trainers, custom model, custom dataloader
---

Both the Ultralytics YOLO command-line and Python interfaces are high-level abstractions over the base engine executors. Let's take a look at the Trainer engine.

## BaseTrainer

BaseTrainer contains the generic boilerplate training routine. It can be customized for any task by overriding the required functions or operations, as long as the correct formats are followed. For example, you can support your own custom model and dataloader by overriding just these functions:

* `get_model(cfg, weights)` - The function that builds the model to be trained.
* `get_dataloader()` - The function that builds the dataloader.

More details and source code can be found in the [`BaseTrainer` Reference](../reference/engine/trainer.md).
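The idea behind `BaseTrainer` is the template-method pattern: the base class owns the training loop and calls overridable hooks to build the model and the data. Below is a minimal, library-free sketch of that idea. `ToyBaseTrainer`, `ToyTrainer`, and the hook bodies are hypothetical and heavily simplified; the real hooks take more arguments and work with PyTorch models and dataloaders.

```python
class ToyBaseTrainer:
    """Hypothetical stand-in for a base trainer: owns the loop, delegates construction to hooks."""

    def __init__(self, epochs=2):
        self.epochs = epochs
        self.history = []

    def get_model(self, cfg=None, weights=None):
        raise NotImplementedError("Subclasses must build the model")

    def get_dataloader(self):
        raise NotImplementedError("Subclasses must build the dataloader")

    def train(self):
        # The generic routine never changes; only the hooks it calls do.
        model = self.get_model()
        loader = self.get_dataloader()
        for _ in range(self.epochs):
            epoch_loss = sum(model(batch) for batch in loader)
            self.history.append(epoch_loss)
        return self.history


class ToyTrainer(ToyBaseTrainer):
    def get_model(self, cfg=None, weights=None):
        # Hypothetical "model" that maps a batch to a loss: the mean of the batch.
        return lambda batch: sum(batch) / len(batch)

    def get_dataloader(self):
        # Hypothetical "dataloader": a fixed list of two batches.
        return [[1.0, 3.0], [2.0, 4.0]]


print(ToyTrainer(epochs=2).train())  # → [5.0, 5.0]
```

Calling `ToyBaseTrainer().train()` directly raises `NotImplementedError`, which is why a concrete subclass must supply both hooks.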
## DetectionTrainer

Here's how you can use the YOLOv8 `DetectionTrainer` and customize it.

```python
from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})  # training arguments, e.g. model, data, epochs
trainer.train()
best_weights = trainer.best  # path to the best checkpoint (best.pt)
```
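The `overrides` argument is a dictionary of training arguments layered on top of the trainer's defaults. The sketch below shows that layering in isolation. `DEFAULTS` and `merge_overrides` are hypothetical, the default values are placeholders, and rejecting unknown keys is an assumption made for illustration rather than a description of the library's exact validation.

```python
# Hypothetical placeholder defaults, not the library's actual configuration.
DEFAULTS = {"model": None, "data": None, "epochs": 100, "batch": 16}


def merge_overrides(defaults, overrides=None):
    """Return a new config with overrides layered over defaults; reject unknown keys."""
    overrides = overrides or {}
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise KeyError(f"Unknown argument(s): {sorted(unknown)}")
    # Later keys win, so user-supplied values replace the defaults.
    return {**defaults, **overrides}


cfg = merge_overrides(DEFAULTS, {"data": "my_dataset.yaml", "epochs": 3})
print(cfg["epochs"], cfg["batch"])  # → 3 16
```

Returning a new dictionary keeps `DEFAULTS` unchanged, so several trainers can be configured from the same defaults.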
### Customizing the DetectionTrainer

Let's customize the trainer **to train a custom detection model** that is not supported directly. You can do this by overriding the existing `get_model` method:

```python
from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg=None, weights=None, verbose=True):
        """Build and return your custom detection model here."""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()
```
You now realize that you need to customize the trainer further to:

* Customize the loss function.
* Add a callback that uploads the model to your Google Drive after every 10 epochs.

Here's how you can do it:

```python
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """Return your custom loss (criterion) here."""
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return an instance of your custom model."""
        return MyCustomModel(...)


# Callback to upload model weights
def log_model(trainer):
    """Upload the latest checkpoint after every 10th epoch."""
    if (trainer.epoch + 1) % 10 == 0:  # trainer.epoch is zero-based
        last_weight_path = trainer.last
        ...  # upload last_weight_path to Google Drive


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # appends to any existing callbacks for this event
trainer.train()
```
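Overriding `init_criterion` changes the loss because the model asks that hook for its criterion the first time a loss is computed. The library-free sketch below mimics this lazy-construction pattern. `ToyModel`, `MyToyModel`, and the `loss` method are hypothetical stand-ins, and the exact point at which Ultralytics calls `init_criterion` is assumed rather than taken from its source.

```python
class ToyModel:
    """Hypothetical model that creates its loss function on first use."""

    criterion = None

    def init_criterion(self):
        # Default loss: mean squared error.
        return lambda pred, target: sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

    def loss(self, pred, target):
        if self.criterion is None:  # build the criterion lazily via the overridable hook
            self.criterion = self.init_criterion()
        return self.criterion(pred, target)


class MyToyModel(ToyModel):
    def init_criterion(self):
        # Custom loss: mean absolute error.
        return lambda pred, target: sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


pred, target = [1.0, 4.0], [0.0, 2.0]
print(ToyModel().loss(pred, target))  # → 2.5
print(MyToyModel().loss(pred, target))  # → 1.5
```

Because the base class only calls the hook, the subclass swaps the loss without touching the training code.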
To learn more about callback triggering events and entry points, check out our [Callbacks Guide](callbacks.md).
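Conceptually, callbacks are functions registered under an event name and called with the trainer whenever that event fires, and registering another one appends to the list rather than replacing it. The toy registry below illustrates that mechanism, together with a callback that acts only after every 10 epochs as described in the list above. `ToyTrainer`, `run_callbacks`, `uploaded`, and `seen` are hypothetical stand-ins, not Ultralytics' implementation.

```python
from collections import defaultdict


class ToyTrainer:
    """Hypothetical trainer that fires named callback events."""

    def __init__(self, epochs):
        self.epochs = epochs
        self.epoch = 0
        self.callbacks = defaultdict(list)

    def add_callback(self, event, func):
        self.callbacks[event].append(func)  # append: earlier callbacks keep running

    def run_callbacks(self, event):
        for func in self.callbacks[event]:
            func(self)

    def train(self):
        for epoch in range(self.epochs):
            self.epoch = epoch  # zero-based epoch index
            # ... one epoch of training would happen here ...
            self.run_callbacks("on_train_epoch_end")


uploaded = []


def log_model(trainer):
    # Act only after every 10th epoch (epochs 10, 20, ... counting from 1).
    if (trainer.epoch + 1) % 10 == 0:
        uploaded.append(trainer.epoch + 1)


seen = []
trainer = ToyTrainer(epochs=25)
trainer.add_callback("on_train_epoch_end", lambda t: seen.append(t.epoch))
trainer.add_callback("on_train_epoch_end", log_model)
trainer.train()
print(uploaded)  # → [10, 20]
print(len(seen))  # → 25
```

Both callbacks fire on every epoch; the modulo check inside `log_model` is what limits the upload to every 10th one.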
## Other engine components

Other components, such as `Validators` and `Predictors`, can be customized in the same way. See the Reference section for more information on these.