@@ -1,16 +1,18 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 
 import json
 import signal
 import sys
 from pathlib import Path
-from time import sleep
+from time import sleep, time
 
 import requests
 
 from ultralytics import __version__
 from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request
-from ultralytics.yolo.utils import is_colab, threaded
-
-AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
+from ultralytics.yolo.utils import is_colab, threaded, LOGGER, emojis, PREFIX
+from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
+
+AGENT_NAME = (f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local")
+session = None
 
@@ -19,23 +21,37 @@ class HubTrainingSession:
     def __init__(self, model_id, auth):
         self.agent_id = None  # identifies which instance is communicating with server
         self.model_id = model_id
-        self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
+        self.api_url = f"{HUB_API_ROOT}/v1/models/{model_id}"
         self.auth_header = auth.get_auth_header()
-        self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds)
-        self.t = {}  # rate limit timers (seconds)
-        self.metrics_queue = {}  # metrics queue
-        self.alive = True  # for heartbeats
+        self._rate_limits = {"metrics": 3.0, "ckpt": 900.0, "heartbeat": 300.0}  # rate limits (seconds)
+        self._timers = {}  # rate limit timers (seconds)
+        self._metrics_queue = {}  # metrics queue
         self.model = self._get_model()
-        self._heartbeats()  # start heartbeats
-        signal.signal(signal.SIGTERM, self.shutdown)  # register the shutdown function to be called on exit
-        signal.signal(signal.SIGINT, self.shutdown)
+        self._start_heartbeat()  # start heartbeats
+        self._register_signal_handlers()
 
-    def shutdown(self, *args):  # noqa
-        self.alive = False  # stop heartbeats
+    def _register_signal_handlers(self):
+        signal.signal(signal.SIGTERM, self._handle_signal)
+        signal.signal(signal.SIGINT, self._handle_signal)
+
+    def _handle_signal(self, signum, frame):
+        """
+        Prevent heartbeats from being sent on Colab after kill.
+        This method does not use frame, it is included as it is
+        passed by signal.
+        """
+        if self.alive is True:
+            LOGGER.info(f"{PREFIX}Kill signal received! ❌")
+            self._stop_heartbeat()
+            sys.exit(signum)
+
+    def _stop_heartbeat(self):
+        """End the heartbeat loop"""
+        self.alive = False
 
     def upload_metrics(self):
-        payload = {"metrics": self.metrics_queue.copy(), "type": "metrics"}
-        smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)
+        payload = {"metrics": self._metrics_queue.copy(), "type": "metrics"}
+        smart_request(f"{self.api_url}", json=payload, headers=self.auth_header, code=2)
 
     def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
         # Upload a model to HUB
@@ -44,25 +60,29 @@ class HubTrainingSession:
         with open(weights, "rb") as f:
             file = f.read()
         if final:
-            smart_request(f'{self.api_url}/upload',
-                          data={
-                              "epoch": epoch,
-                              "type": "final",
-                              "map": map},
-                          files={"best.pt": file},
-                          headers=self.auth_header,
-                          retry=10,
-                          timeout=3600,
-                          code=4)
+            smart_request(
+                f"{self.api_url}/upload",
+                data={
+                    "epoch": epoch,
+                    "type": "final",
+                    "map": map},
+                files={"best.pt": file},
+                headers=self.auth_header,
+                retry=10,
+                timeout=3600,
+                code=4,
+            )
         else:
-            smart_request(f'{self.api_url}/upload',
-                          data={
-                              "epoch": epoch,
-                              "type": "epoch",
-                              "isBest": bool(is_best)},
-                          headers=self.auth_header,
-                          files={"last.pt": file},
-                          code=3)
+            smart_request(
+                f"{self.api_url}/upload",
+                data={
+                    "epoch": epoch,
+                    "type": "epoch",
+                    "isBest": bool(is_best)},
+                headers=self.auth_header,
+                files={"last.pt": file},
+                code=3,
+            )
 
     def _get_model(self):
         # Returns model from database by id
@@ -70,31 +90,131 @@ class HubTrainingSession:
         headers = self.auth_header
 
         try:
-            r = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
-            data = r.json().get("data", None)
-            if not data:
-                return
-            assert data['data'], 'ERROR: Dataset may still be processing. Please wait a minute and try again.'  # RF fix
+            response = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
+            data = response.json().get("data", None)
+
+            if data.get("status", None) == "trained":
+                raise ValueError(
+                    emojis(f"Model trained. View model at https://hub.ultralytics.com/models/{self.model_id} 🚀"))
+
+            if not data.get("data", None):
+                raise ValueError("Dataset may still be processing. Please wait a minute and try again.")  # RF fix
             self.model_id = data["id"]
 
+            # TODO: restore when server keys when dataset URL and GPU train is working
+            self.train_args = {
+                "batch": data["batch_size"],
+                "epochs": data["epochs"],
+                "imgsz": data["imgsz"],
+                "patience": data["patience"],
+                "device": data["device"],
+                "cache": data["cache"],
+                "data": data["data"]}
+
+            self.input_file = data.get("cfg", data["weights"])
+
+            # hack for yolov5 cfg adds u
+            if "cfg" in data and "yolov5" in data["cfg"]:
+                self.input_file = data["cfg"].replace(".yaml", "u.yaml")
+
             return data
         except requests.exceptions.ConnectionError as e:
-            raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
+            raise ConnectionRefusedError("ERROR: The HUB server is not online. Please try again later.") from e
+        except Exception:
+            raise
 
     def check_disk_space(self):
-        if not check_dataset_disk_space(self.model['data']):
+        if not check_dataset_disk_space(self.model["data"]):
             raise MemoryError("Not enough disk space")
 
+    def register_callbacks(self, trainer):
+        trainer.add_callback("on_pretrain_routine_end", self.on_pretrain_routine_end)
+        trainer.add_callback("on_fit_epoch_end", self.on_fit_epoch_end)
+        trainer.add_callback("on_model_save", self.on_model_save)
+        trainer.add_callback("on_train_end", self.on_train_end)
+
+    def on_pretrain_routine_end(self, trainer):
+        """
+        Start timer for upload rate limit.
+
+        This method does not use trainer. It is passed to all callbacks by default.
+        """
+        # Start timer for upload rate limit
+        LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
+        self._timers = {"metrics": time(), "ckpt": time()}  # start timer on self.rate_limit
+
+    def on_fit_epoch_end(self, trainer):
+        # Upload metrics after val end
+        all_plots = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics}
+
+        if trainer.epoch == 0:
+            model_info = {
+                "model/parameters": get_num_params(trainer.model),
+                "model/GFLOPs": round(get_flops(trainer.model), 3),
+                "model/speed(ms)": round(trainer.validator.speed[1], 3)}
+            all_plots = {**all_plots, **model_info}
+        self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
+        if time() - self._timers["metrics"] > self._rate_limits["metrics"]:
+            self.upload_metrics()
+            self._timers["metrics"] = time()  # reset timer
+            self._metrics_queue = {}  # reset queue
+
+    def on_model_save(self, trainer):
+        # Upload checkpoints with rate limiting
+        is_best = trainer.best_fitness == trainer.fitness
+        if time() - self._timers["ckpt"] > self._rate_limits["ckpt"]:
+            LOGGER.info(f"{PREFIX}Uploading checkpoint {self.model_id}")
+            self._upload_model(trainer.epoch, trainer.last, is_best)
+            self._timers["ckpt"] = time()  # reset timer
+
+    def on_train_end(self, trainer):
+        # Upload final model and metrics with exponential standoff
+        LOGGER.info(f"{PREFIX}Training completed successfully ✅")
+        LOGGER.info(f"{PREFIX}Uploading final {self.model_id}")
+
+        # hack for fetching mAP
+        mAP = trainer.metrics.get("metrics/mAP50-95(B)", 0)
+        self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True)  # results[3] is mAP0.5:0.95
+        self.alive = False  # stop heartbeats
+        LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀")
+
+    def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
+        # Upload a model to HUB
+        file = None
+        if Path(weights).is_file():
+            with open(weights, "rb") as f:
+                file = f.read()
+        file_param = {"best.pt" if final else "last.pt": file}
+        endpoint = f"{self.api_url}/upload"
+        data = {"epoch": epoch}
+        if final:
+            data.update({"type": "final", "map": map})
+        else:
+            data.update({"type": "epoch", "isBest": bool(is_best)})
+
+        smart_request(
+            endpoint,
+            data=data,
+            files=file_param,
+            headers=self.auth_header,
+            retry=10 if final else None,
+            timeout=3600 if final else None,
+            code=4 if final else 3,
+        )
+
     @threaded
-    def _heartbeats(self):
+    def _start_heartbeat(self):
+        self.alive = True
         while self.alive:
-            r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
-                              json={
-                                  "agent": AGENT_NAME,
-                                  "agentId": self.agent_id},
-                              headers=self.auth_header,
-                              retry=0,
-                              code=5,
-                              thread=False)
-            self.agent_id = r.json().get('data', {}).get('agentId', None)
-            sleep(self.rate_limits['heartbeat'])
+            r = smart_request(
+                f"{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}",
+                json={
+                    "agent": AGENT_NAME,
+                    "agentId": self.agent_id},
+                headers=self.auth_header,
+                retry=0,
+                code=5,
+                thread=False,
+            )
+            self.agent_id = r.json().get("data", {}).get("agentId", None)
+            sleep(self._rate_limits["heartbeat"])
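
# Illustration: the on_fit_epoch_end and on_model_save callbacks in the diff
# throttle uploads with a timer-and-queue pattern: payloads are buffered every
# epoch and only flushed once the per-key rate limit (self._rate_limits) has
# elapsed. A minimal standalone sketch of that pattern follows; RateLimitedQueue
# and its method names are hypothetical and only mirror the logic shown above,
# they are not part of the Ultralytics HUB API.
from time import time


class RateLimitedQueue:
    """Buffer payloads and flush them at most once per `interval` seconds."""

    def __init__(self, interval=3.0):
        self.interval = interval  # seconds between flushes, e.g. 3.0 for metrics
        self.last_flush = time()  # timer starts when training begins
        self.queue = {}  # buffered payloads keyed by epoch

    def add(self, epoch, payload):
        self.queue[epoch] = payload  # keep the latest payload per epoch

    def flush(self, send):
        # Send the buffered payloads only if the rate limit has elapsed
        if self.queue and time() - self.last_flush > self.interval:
            send(self.queue.copy())  # e.g. a call like upload_metrics()
            self.queue = {}  # reset queue
            self.last_flush = time()  # reset timer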
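
# Illustration: _start_heartbeat in the diff loops under the @threaded
# decorator (which presumably runs the function in its own daemon thread) and
# pings the server every `heartbeat` seconds until self.alive is cleared by
# _stop_heartbeat. Below is a minimal standard-library sketch of that loop;
# Heartbeat and ping are hypothetical names used only for illustration.
import threading
from time import sleep


class Heartbeat:

    def __init__(self, interval=300.0):
        self.interval = interval  # seconds between pings
        self.alive = False

    def start(self, ping):
        self.alive = True
        # daemon=True so the loop never blocks interpreter shutdown
        threading.Thread(target=self._loop, args=(ping,), daemon=True).start()

    def stop(self):
        self.alive = False  # the loop exits after its current sleep

    def _loop(self, ping):
        while self.alive:
            ping()  # e.g. POST to the heartbeat endpoint
            sleep(self.interval)


# Hypothetical usage:
# hb = Heartbeat(interval=300.0)
# hb.start(lambda: print("ping"))
# ...
# hb.stop()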
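
# Illustration: the upload paths in the diff pass retry=10 and timeout=3600 to
# smart_request for the final model, and no explicit retry/timeout for epoch
# checkpoints. The sketch below shows a generic retry-with-exponential-backoff
# wrapper around requests.post, only to illustrate what such parameters usually
# control; it is an assumption about the general pattern, not the smart_request
# implementation, and post_with_retry is a hypothetical name.
from time import sleep, time

import requests


def post_with_retry(url, *, retry=10, timeout=3600, backoff=2.0, **kwargs):
    """POST to `url`, retrying up to `retry` times within a `timeout` second budget."""
    start, attempt = time(), 0
    while True:
        try:
            response = requests.post(url, **kwargs)
            if response.status_code < 400:
                return response  # success, no retry needed
        except requests.exceptions.RequestException:
            pass  # network error, fall through to the retry logic below
        attempt += 1
        if attempt > retry or time() - start > timeout:
            return None  # retry/timeout budget exhausted, give up
        sleep(backoff ** attempt)  # exponential backoff between attempts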