You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
4.0 KiB
100 lines
4.0 KiB
# Ultralytics YOLO 🚀, GPL-3.0 license |
|
import signal |
|
from pathlib import Path |
|
from time import sleep |
|
|
|
import requests |
|
|
|
from ultralytics import __version__ |
|
from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request |
|
from ultralytics.yolo.utils import is_colab, threaded |
|
|
|
# Agent tag sent in the "agent" field of every heartbeat request: the package
# version plus a suffix recording whether is_colab() reported a Colab runtime.
AGENT_NAME = f"python-{__version__}-{'colab' if is_colab() else 'local'}"

# Module-level session slot; starts out unset.
# NOTE(review): presumably assigned a HubTrainingSession elsewhere - confirm against callers.
session = None
|
|
|
|
|
class HubTrainingSession:
    """Session that ties a local training run to one model record on the HUB server.

    On construction the session fetches the model record, starts the heartbeat
    loop and installs itself as the SIGTERM/SIGINT handler. The upload methods
    then push metrics and checkpoint files to the model's API endpoint.

    Attributes:
        agent_id: Identifier returned by the heartbeat endpoint; None until the
            first heartbeat response has been parsed.
        model_id: Model identifier; replaced by the server-reported id when the
            model record is fetched successfully.
        api_url: Endpoint built from the model_id passed to the constructor.
        auth_header: Headers obtained from ``auth.get_auth_header()``.
        rate_limits: Minimum intervals in seconds, keyed by request kind. Only
            'heartbeat' is read inside this class.
        t: Rate-limit timers (seconds); not read inside this class.
        metrics_queue: Metrics waiting to be sent by ``upload_metrics``.
        alive: Heartbeat loop flag; cleared by ``shutdown``.
        model: Model record returned by ``_get_model`` (may be None).
    """

    def __init__(self, model_id, auth):
        """Fetch the model record, start heartbeats and register signal handlers.

        Args:
            model_id: Identifier of the model on the HUB server.
            auth: Object exposing ``get_auth_header()``; its result is sent as
                the headers of every request made by this session.

        Raises:
            ConnectionRefusedError: If the server cannot be reached while
                fetching the model record.
            AssertionError: If the fetched record has no dataset entry yet.
        """
        self.agent_id = None  # identifies which instance is communicating with server
        self.model_id = model_id
        self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
        self.auth_header = auth.get_auth_header()
        self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds)
        self.t = {}  # rate limit timers (seconds)
        self.metrics_queue = {}  # metrics queue
        self.alive = True  # must be set before the heartbeat loop starts
        self.model = self._get_model()
        self._heartbeats()  # start heartbeats
        # Route termination signals to shutdown() so the heartbeat loop is told to stop.
        # NOTE(review): installing a SIGINT handler replaces Python's default one, so
        # Ctrl+C only clears `alive` and no longer raises KeyboardInterrupt - confirm intended.
        signal.signal(signal.SIGTERM, self.shutdown)
        signal.signal(signal.SIGINT, self.shutdown)

    def shutdown(self, *args):  # noqa
        """Stop the heartbeat loop; accepts and ignores the (signum, frame) handler arguments."""
        self.alive = False  # the loop exits the next time it checks the flag

    def upload_metrics(self):
        """Send a snapshot of ``metrics_queue`` to the model endpoint.

        A shallow copy is sent, so the queue itself is left untouched here.
        """
        payload = {"metrics": self.metrics_queue.copy(), "type": "metrics"}
        smart_request(self.api_url, json=payload, headers=self.auth_header, code=2)

    def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
        """Upload a checkpoint file to the model's ``/upload`` endpoint.

        Args:
            epoch: Epoch number reported with the upload.
            weights: Path to the checkpoint file on disk.
            is_best: Reported as "isBest" for per-epoch uploads; unused when
                ``final`` is true.
            map: Value reported as "map" for the final upload; unused otherwise.
            final: When true, upload as "best.pt" with type "final", 10 retries
                and a 3600 timeout; otherwise upload as "last.pt" with type
                "epoch" and smart_request's defaults.
        """
        weights_path = Path(weights)
        # When the path is not an existing file the request is still issued,
        # with None in place of the file contents.
        file = weights_path.read_bytes() if weights_path.is_file() else None
        upload_url = f'{self.api_url}/upload'
        if final:
            smart_request(upload_url,
                          data={
                              "epoch": epoch,
                              "type": "final",
                              "map": map},
                          files={"best.pt": file},
                          headers=self.auth_header,
                          retry=10,
                          timeout=3600,
                          code=4)
        else:
            smart_request(upload_url,
                          data={
                              "epoch": epoch,
                              "type": "epoch",
                              "isBest": bool(is_best)},
                          headers=self.auth_header,
                          files={"last.pt": file},
                          code=3)

    def _get_model(self):
        """Fetch this session's model record from the server.

        Returns:
            The "data" object of the JSON response, or None when the response
            carries no (or empty) "data".

        Raises:
            AssertionError: If the record has no non-empty 'data' entry.
            ConnectionRefusedError: If the request fails with a requests
                ConnectionError.
        """
        api_url = f"{HUB_API_ROOT}/v1/models/{self.model_id}"
        headers = self.auth_header

        try:
            r = smart_request(api_url, method="get", headers=headers, thread=False, code=0)
            data = r.json().get("data", None)
            if not data:
                return None
            # Raised explicitly rather than via `assert` so the check survives `python -O`;
            # `.get` makes a missing key report the same message instead of a bare KeyError.
            if not data.get('data'):
                raise AssertionError(
                    'ERROR: Dataset may still be processing. Please wait a minute and try again.')  # RF fix
            self.model_id = data["id"]  # adopt the id reported by the server

            return data
        except requests.exceptions.ConnectionError as e:
            raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e

    def check_disk_space(self):
        """Raise MemoryError when ``check_dataset_disk_space`` rejects the model's dataset entry.

        NOTE(review): ``self.model`` is None when ``_get_model`` found no record,
        in which case the subscript below raises TypeError - confirm callers guard this.
        """
        if not check_dataset_disk_space(self.model['data']):
            raise MemoryError("Not enough disk space")

    @threaded
    def _heartbeats(self):
        """Post a heartbeat for this model every ``rate_limits['heartbeat']`` seconds while alive.

        Each response's "agentId" is stored in ``agent_id`` and echoed back in
        the next heartbeat.

        NOTE(review): ``threaded`` is imported from ultralytics.yolo.utils and
        presumably runs this loop off the calling thread; ``__init__`` relies on
        the call returning while the loop is still running - confirm.
        """
        while self.alive:
            r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
                              json={
                                  "agent": AGENT_NAME,
                                  "agentId": self.agent_id},
                              headers=self.auth_header,
                              retry=0,
                              code=5,
                              thread=False)
            # dict.get's default only applies when the key is absent; `or {}` also covers
            # a present-but-null "data", which would otherwise raise and end the loop.
            self.agent_id = (r.json().get('data') or {}).get('agentId', None)
            sleep(self.rate_limits['heartbeat'])
|
|
|