HUB setup (#108)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/128/head
parent
c6eb6720de
commit
2bc9a5c87e
16 changed files with 632 additions and 123 deletions
@ -0,0 +1,69 @@ |
|||||||
|
import requests |
||||||
|
|
||||||
|
from ultralytics.hub.config import HUB_API_ROOT |
||||||
|
from ultralytics.hub.utils import request_with_credentials |
||||||
|
from ultralytics.yolo.utils import is_colab |
||||||
|
|
||||||
|
API_KEY_PATH = "https://hub.ultralytics.com/settings?tab=api+keys" |
||||||
|
|
||||||
|
|
||||||
|
class Auth: |
||||||
|
id_token = api_key = model_key = False |
||||||
|
|
||||||
|
def __init__(self, api_key=None): |
||||||
|
self.api_key = self._clean_api_key(api_key) |
||||||
|
self.authenticate() if self.api_key else self.auth_with_cookies() |
||||||
|
|
||||||
|
@staticmethod |
||||||
|
def _clean_api_key(key: str) -> str: |
||||||
|
"""Strip model from key if present""" |
||||||
|
separator = "_" |
||||||
|
return key.split(separator)[0] if separator in key else key |
||||||
|
|
||||||
|
def authenticate(self) -> bool: |
||||||
|
"""Attempt to authenticate with server""" |
||||||
|
try: |
||||||
|
header = self.get_auth_header() |
||||||
|
if header: |
||||||
|
r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header) |
||||||
|
if not r.json().get('success', False): |
||||||
|
raise ConnectionError("Unable to authenticate.") |
||||||
|
return True |
||||||
|
raise ConnectionError("User has not authenticated locally.") |
||||||
|
except ConnectionError: |
||||||
|
self.id_token = self.api_key = False # reset invalid |
||||||
|
return False |
||||||
|
|
||||||
|
def auth_with_cookies(self) -> bool: |
||||||
|
""" |
||||||
|
Attempt to fetch authentication via cookies and set id_token. |
||||||
|
User must be logged in to HUB and running in a supported browser. |
||||||
|
""" |
||||||
|
if not is_colab(): |
||||||
|
return False # Currently only works with Colab |
||||||
|
try: |
||||||
|
authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto") |
||||||
|
if authn.get("success", False): |
||||||
|
self.id_token = authn.get("data", {}).get("idToken", None) |
||||||
|
self.authenticate() |
||||||
|
return True |
||||||
|
raise ConnectionError("Unable to fetch browser authentication details.") |
||||||
|
except ConnectionError: |
||||||
|
self.id_token = False # reset invalid |
||||||
|
return False |
||||||
|
|
||||||
|
def get_auth_header(self): |
||||||
|
if self.id_token: |
||||||
|
return {"authorization": f"Bearer {self.id_token}"} |
||||||
|
elif self.api_key: |
||||||
|
return {"x-api-key": self.api_key} |
||||||
|
else: |
||||||
|
return None |
||||||
|
|
||||||
|
def get_state(self) -> bool: |
||||||
|
"""Get the authentication state""" |
||||||
|
return self.id_token or self.api_key |
||||||
|
|
||||||
|
def set_api_key(self, key: str): |
||||||
|
"""Get the authentication state""" |
||||||
|
self.api_key = key |
@ -0,0 +1,12 @@ |
|||||||
|
import os |
||||||
|
|
||||||
|
# Global variables |
||||||
|
REPO_URL = "https://github.com/ultralytics/yolov5.git" |
||||||
|
REPO_BRANCH = "ultralytics/HUB" # "master" |
||||||
|
|
||||||
|
ENVIRONMENT = os.environ.get("ULTRALYTICS_ENV", "production") |
||||||
|
if ENVIRONMENT == 'production': |
||||||
|
HUB_API_ROOT = "https://api.ultralytics.com" |
||||||
|
else: |
||||||
|
HUB_API_ROOT = "http://127.0.0.1:8000" |
||||||
|
print(f'Connected to development server on {HUB_API_ROOT}') |
@ -0,0 +1,121 @@ |
|||||||
|
import signal |
||||||
|
import sys |
||||||
|
from pathlib import Path |
||||||
|
from time import sleep |
||||||
|
|
||||||
|
import requests |
||||||
|
|
||||||
|
from ultralytics import __version__ |
||||||
|
from ultralytics.hub.config import HUB_API_ROOT |
||||||
|
from ultralytics.hub.utils import check_dataset_disk_space, smart_request |
||||||
|
from ultralytics.yolo.utils import LOGGER, is_colab, threaded |
||||||
|
|
||||||
|
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local' |
||||||
|
|
||||||
|
session = None |
||||||
|
|
||||||
|
|
||||||
|
def signal_handler(signum, frame): |
||||||
|
""" Confirm exit """ |
||||||
|
global hub_logger |
||||||
|
LOGGER.info(f'Signal received. {signum} {frame}') |
||||||
|
if isinstance(session, HubTrainingSession): |
||||||
|
hub_logger.alive = False |
||||||
|
del hub_logger |
||||||
|
sys.exit(signum) |
||||||
|
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler) |
||||||
|
signal.signal(signal.SIGINT, signal_handler) |
||||||
|
|
||||||
|
|
||||||
|
class HubTrainingSession: |
||||||
|
|
||||||
|
def __init__(self, model_id, auth): |
||||||
|
self.agent_id = None # identifies which instance is communicating with server |
||||||
|
self.model_id = model_id |
||||||
|
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}' |
||||||
|
self.auth_header = auth.get_auth_header() |
||||||
|
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds) |
||||||
|
self.t = {} # rate limit timers (seconds) |
||||||
|
self.metrics_queue = {} # metrics queue |
||||||
|
self.alive = True # for heartbeats |
||||||
|
self.model = self._get_model() |
||||||
|
self._heartbeats() # start heartbeats |
||||||
|
|
||||||
|
def __del__(self): |
||||||
|
# Class destructor |
||||||
|
self.alive = False |
||||||
|
|
||||||
|
def upload_metrics(self): |
||||||
|
payload = {"metrics": self.metrics_queue.copy(), "type": "metrics"} |
||||||
|
smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2) |
||||||
|
|
||||||
|
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): |
||||||
|
# Upload a model to HUB |
||||||
|
file = None |
||||||
|
if Path(weights).is_file(): |
||||||
|
with open(weights, "rb") as f: |
||||||
|
file = f.read() |
||||||
|
if final: |
||||||
|
smart_request(f'{self.api_url}/upload', |
||||||
|
data={ |
||||||
|
"epoch": epoch, |
||||||
|
"type": "final", |
||||||
|
"map": map}, |
||||||
|
files={"best.pt": file}, |
||||||
|
headers=self.auth_header, |
||||||
|
retry=10, |
||||||
|
timeout=3600, |
||||||
|
code=4) |
||||||
|
else: |
||||||
|
smart_request(f'{self.api_url}/upload', |
||||||
|
data={ |
||||||
|
"epoch": epoch, |
||||||
|
"type": "epoch", |
||||||
|
"isBest": bool(is_best)}, |
||||||
|
headers=self.auth_header, |
||||||
|
files={"last.pt": file}, |
||||||
|
code=3) |
||||||
|
|
||||||
|
def _get_model(self): |
||||||
|
# Returns model from database by id |
||||||
|
api_url = f"{HUB_API_ROOT}/v1/models/{self.model_id}" |
||||||
|
headers = self.auth_header |
||||||
|
|
||||||
|
try: |
||||||
|
r = smart_request(api_url, method="get", headers=headers, thread=False, code=0) |
||||||
|
data = r.json().get("data", None) |
||||||
|
if not data: |
||||||
|
return |
||||||
|
assert data['data'], 'ERROR: Dataset may still be processing. Please wait a minute and try again.' # RF fix |
||||||
|
self.model_id = data["id"] |
||||||
|
|
||||||
|
return data |
||||||
|
except requests.exceptions.ConnectionError as e: |
||||||
|
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e |
||||||
|
|
||||||
|
def check_disk_space(self): |
||||||
|
if not check_dataset_disk_space(self.model['data']): |
||||||
|
raise MemoryError("Not enough disk space") |
||||||
|
|
||||||
|
# COMMENT: Should not be needed as HUB is now considered an integration and is in integrations_callbacks |
||||||
|
# import ultralytics.yolo.utils.callbacks.hub as hub_callbacks |
||||||
|
# @staticmethod |
||||||
|
# def register_callbacks(trainer): |
||||||
|
# for k, v in hub_callbacks.callbacks.items(): |
||||||
|
# trainer.add_callback(k, v) |
||||||
|
|
||||||
|
@threaded |
||||||
|
def _heartbeats(self): |
||||||
|
while self.alive: |
||||||
|
r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', |
||||||
|
json={ |
||||||
|
"agent": AGENT_NAME, |
||||||
|
"agentId": self.agent_id}, |
||||||
|
headers=self.auth_header, |
||||||
|
retry=0, |
||||||
|
code=5, |
||||||
|
thread=False) |
||||||
|
self.agent_id = r.json().get('data', {}).get('agentId', None) |
||||||
|
sleep(self.rate_limits['heartbeat']) |
@ -0,0 +1,80 @@ |
|||||||
|
import json |
||||||
|
from time import time |
||||||
|
|
||||||
|
import torch |
||||||
|
|
||||||
|
from ultralytics.hub.utils import PREFIX, sync_analytics |
||||||
|
from ultralytics.yolo.utils import LOGGER |
||||||
|
|
||||||
|
|
||||||
|
def on_pretrain_routine_end(trainer): |
||||||
|
session = getattr(trainer, 'hub_session', None) |
||||||
|
if session: |
||||||
|
# Start timer for upload rate limit |
||||||
|
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀") |
||||||
|
session.t = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit |
||||||
|
|
||||||
|
|
||||||
|
def on_fit_epoch_end(trainer): |
||||||
|
session = getattr(trainer, 'hub_session', None) |
||||||
|
if session: |
||||||
|
# Upload metrics after val end |
||||||
|
metrics = trainer.metrics |
||||||
|
for k, v in metrics.items(): |
||||||
|
if isinstance(v, torch.Tensor): |
||||||
|
metrics[k] = v.item() |
||||||
|
|
||||||
|
session.metrics_queue[trainer.epoch] = json.dumps(metrics) # json string |
||||||
|
if time() - session.t['metrics'] > session.rate_limits['metrics']: |
||||||
|
session.upload_metrics() |
||||||
|
session.t['metrics'] = time() # reset timer |
||||||
|
session.metrics_queue = {} # reset queue |
||||||
|
|
||||||
|
|
||||||
|
def on_model_save(trainer): |
||||||
|
session = getattr(trainer, 'hub_session', None) |
||||||
|
if session: |
||||||
|
# Upload checkpoints with rate limiting |
||||||
|
is_best = trainer.best_fitness == trainer.fitness |
||||||
|
if time() - session.t['ckpt'] > session.rate_limits['ckpt']: |
||||||
|
LOGGER.info(f"{PREFIX}Uploading checkpoint {session.model_id}") |
||||||
|
session.upload_model(trainer.epoch, trainer.last, is_best) |
||||||
|
session.t['ckpt'] = time() # reset timer |
||||||
|
|
||||||
|
|
||||||
|
def on_train_end(trainer): |
||||||
|
session = getattr(trainer, 'hub_session', None) |
||||||
|
if session: |
||||||
|
# Upload final model and metrics with exponential standoff |
||||||
|
LOGGER.info(f"{PREFIX}Training completed successfully ✅\n" |
||||||
|
f"{PREFIX}Uploading final {session.model_id}") |
||||||
|
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50(B)'], final=True) |
||||||
|
session.alive = False # stop heartbeats |
||||||
|
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀") |
||||||
|
|
||||||
|
|
||||||
|
def on_train_start(trainer): |
||||||
|
sync_analytics(trainer.args) |
||||||
|
|
||||||
|
|
||||||
|
def on_val_start(validator): |
||||||
|
sync_analytics(validator.args) |
||||||
|
|
||||||
|
|
||||||
|
def on_predict_start(predictor): |
||||||
|
sync_analytics(predictor.args) |
||||||
|
|
||||||
|
|
||||||
|
def on_export_start(exporter): |
||||||
|
sync_analytics(exporter.args) |
||||||
|
|
||||||
|
|
||||||
|
callbacks = { |
||||||
|
"on_pretrain_routine_end": on_pretrain_routine_end, |
||||||
|
"on_fit_epoch_end": on_fit_epoch_end, |
||||||
|
"on_model_save": on_model_save, |
||||||
|
"on_train_end": on_train_end, |
||||||
|
"on_train_start": on_train_start, |
||||||
|
"on_val_start": on_val_start, |
||||||
|
"on_predict_start": on_predict_start, |
||||||
|
"on_export_start": on_export_start} |
Loading…
Reference in new issue