@ -0,0 +1,112 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import math |
from timm.models.layers import trunc_normal_, DropPath, Mlp |
import torch.nn as nn |
from utils.misc import is_pow2n |
# Normalization layer class shared by the decoder blocks. It stays ``None``
# until ``LightDecoder.__init__`` rebinds it to ``nn.SyncBatchNorm`` or
# ``nn.BatchNorm2d``.
_BN = None


class UNetBlock2x(nn.Module):
    """Two stacked 3x3 conv -> norm -> ReLU6 stages (stride 1, padding 1).

    Args:
        cin: number of input channels.
        cout: number of output channels.
        cmid: selects the hidden width; ``0`` uses ``cin`` and ``1`` uses the
            integer mean of ``cin`` and ``cout``.
        last_act: when false, the trailing ReLU6 is replaced by ``nn.Identity``.

    Raises:
        ValueError: if ``cmid`` is neither 0 nor 1.
    """

    def __init__(self, cin, cout, cmid, last_act=True):
        super().__init__()
        if cmid == 0:
            c_mid = cin
        elif cmid == 1:
            c_mid = (cin + cout) // 2
        else:
            # Previously this fell through and failed later with an unbound
            # local; fail early with an explicit message instead.
            raise ValueError(f'unsupported cmid={cmid!r}; expected 0 or 1')

        final_act = nn.ReLU6(inplace=True) if last_act else nn.Identity()
        self.b = nn.Sequential(
            nn.Conv2d(cin, c_mid, 3, 1, 1, bias=False), _BN(c_mid), nn.ReLU6(inplace=True),
            nn.Conv2d(c_mid, cout, 3, 1, 1, bias=False), _BN(cout), final_act,
        )

    def forward(self, x):
        """Apply the two conv stages to ``x``."""
        return self.b(x)
class DecoderConv(nn.Module):
    """Stride-2 transposed convolution followed by a stack of ``UNetBlock2x``.

    Args:
        cin: channels of the incoming feature map (the transposed conv keeps
            this channel count).
        cout: channels produced by the last ``UNetBlock2x``.
        double: if true the transposed conv uses kernel 4 / padding 1,
            otherwise kernel 2 / padding 0.
        heavy: indexable whose element ``[1]`` is the number of
            ``UNetBlock2x`` blocks to stack.
        cmid: forwarded to every ``UNetBlock2x``.
    """

    def __init__(self, cin, cout, double, heavy, cmid):
        super().__init__()
        if double:
            kernel, pad = 4, 1
        else:
            kernel, pad = 2, 0
        self.up = nn.ConvTranspose2d(cin, cin, kernel_size=kernel, stride=2, padding=pad, bias=True)

        num_blocks = heavy[1]
        blocks = []
        for idx in range(num_blocks):
            is_last = idx == num_blocks - 1
            # Only the final block changes the channel count and drops its
            # trailing activation.
            blocks.append(UNetBlock2x(cin, cout if is_last else cin, cmid=cmid, last_act=not is_last))
        self.conv = nn.Sequential(*blocks)

    def forward(self, x):
        """Upsample ``x`` and run it through the refinement blocks."""
        return self.conv(self.up(x))
class LightDecoder(nn.Module):
    """Stack of ``DecoderConv`` stages followed by a 1x1 projection to 3 channels.

    Args:
        decoder_fea_dim: channel count of the first (coarsest) stage input.
        upsample_ratio: total upsampling factor; must satisfy ``is_pow2n``.
            ``log2(upsample_ratio)`` stages are built and stage ``i`` maps
            ``decoder_fea_dim // 2**i`` to ``decoder_fea_dim // 2**(i+1)``
            channels.
        double: forwarded to every ``DecoderConv``.
        heavy: two-element sequence; element ``[1]`` (clamped to at least 1)
            is the number of blocks per stage. Defaults to ``[0, 1]``.
        cmid: forwarded to every ``DecoderConv``.
        sbn: use ``nn.SyncBatchNorm`` instead of ``nn.BatchNorm2d``.
    """

    def __init__(self, decoder_fea_dim, upsample_ratio, double=False, heavy=None, cmid=0, sbn=False):
        # The module-level _BN must be bound before any sub-block is built,
        # because UNetBlock2x reads it at construction time.
        global _BN
        _BN = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
        super().__init__()
        self.fea_dim = decoder_fea_dim

        # Work on a private copy so the caller's list is never modified.
        heavy = [0, 1] if heavy is None else list(heavy)
        heavy[1] = max(1, heavy[1])
        self.double_bool = double
        self.heavy = heavy
        self.cmid = cmid
        self.sbn = sbn

        assert is_pow2n(upsample_ratio)
        n = round(math.log2(upsample_ratio))
        channels = [self.fea_dim // 2**i for i in range(n+1)]
        self.dec = nn.ModuleList([
            DecoderConv(cin, cout, double, heavy, cmid) for (cin, cout) in zip(channels[:-1], channels[1:])
        ])
        self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True)

        self.initialize()

    def forward(self, to_dec):
        """Decode a sequence of feature maps.

        ``to_dec[i]`` (when present and not ``None``) is added to the running
        feature map right before stage ``i``. The running value starts as the
        integer 0, so the first stage needs a non-``None`` ``to_dec[0]``.
        """
        x = 0
        for i, d in enumerate(self.dec):
            if i < len(to_dec) and to_dec[i] is not None:
                x = x + to_dec[i]
            x = d(x)
        return self.proj(x)

    def num_para(self):
        """Return a formatted summary of parameter counts (in millions)."""
        tot = sum(p.numel() for p in self.parameters())

        para1 = para2 = 0
        for m in self.dec.modules():
            if isinstance(m, nn.ConvTranspose2d):
                para1 += sum(p.numel() for p in m.parameters())
            elif isinstance(m, nn.Conv2d):
                para2 += sum(p.numel() for p in m.parameters())
        return f'#para: {tot/1e6:.2f} (dconv={para1/1e6:.2f}, conv={para2/1e6:.2f}, ot={(tot-para1-para2)/1e6:.2f})'

    def extra_repr(self) -> str:
        return f'fea_dim={self.fea_dim}, dbl={self.double_bool}, heavy={self.heavy}, cmid={self.cmid}, sbn={self.sbn}'

    def initialize(self):
        """Initialize the weights of every sub-module according to its type."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                trunc_normal_(m.weight, std=.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # nn.Conv2d is already handled by the branch above, so only
                # transposed convolutions can reach this one.
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)
@ -0,0 +1,142 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import functools |
import os |
from typing import List |
from typing import Union |
import torch |
import torch.distributed as tdist |
import torch.multiprocessing as mp |
# Process-wide distributed state. The defaults below are what the accessors
# report until ``initialize`` overwrites them.
__rank = 0
__local_rank = 0
__world_size = 1
__device = 'cpu'
__initialized = False


def initialized():
    """Return whether ``initialize`` has completed in this process."""
    return __initialized
def initialize(backend='nccl'):
    """Initialize ``torch.distributed`` for this process and cache rank info.

    Reads the global rank from the ``RANK`` environment variable, selects the
    CUDA device ``RANK % torch.cuda.device_count()``, creates the default
    process group and stores rank, local rank, world size and device in the
    module-level state.

    Args:
        backend: backend name passed to ``init_process_group``.

    Raises:
        ValueError: if ``RANK`` is missing or is not an integer.
        RuntimeError: if no CUDA device is visible.
    """
    global __rank, __local_rank, __world_size, __device, __initialized

    rank_str = os.environ.get('RANK')
    if rank_str is None:
        raise ValueError("environment variable 'RANK' is not set; start this process through a distributed launcher")
    try:
        global_rank = int(rank_str)
    except ValueError:
        raise ValueError(f"environment variable 'RANK' must be an integer, got {rank_str!r}") from None

    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        # Avoids an opaque ZeroDivisionError in the modulo below.
        raise RuntimeError('no CUDA device is visible; cannot pick a local rank')

    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)  # do not pass init_method='env://'

    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
def get_rank():
    """Return the cached global rank (0 until ``initialize`` has run)."""
    return __rank


def get_local_rank():
    """Return the cached local rank (0 until ``initialize`` has run)."""
    return __local_rank


def get_world_size():
    """Return the cached world size (1 until ``initialize`` has run)."""
    return __world_size


def get_device():
    """Return the cached device (``'cpu'`` until ``initialize`` has run)."""
    return __device


def is_master():
    """Return True when the cached global rank is 0."""
    return 0 == __rank


def is_local_master():
    """Return True when the cached local rank is 0."""
    return 0 == __local_rank
def parallelize(net, syncbn=False):
    """Move ``net`` to CUDA and wrap it in ``DistributedDataParallel``.

    Args:
        net: the module to wrap.
        syncbn: when true, convert the module with
            ``SyncBatchNorm.convert_sync_batchnorm`` first.

    Returns:
        The DDP-wrapped module, bound to the device index ``get_local_rank()``.
    """
    module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) if syncbn else net
    module = module.cuda()
    return torch.nn.parallel.DistributedDataParallel(
        module,
        device_ids=[get_local_rank()],
        find_unused_parameters=False,
        broadcast_buffers=False,
    )
def new_group(ranks: List[int]):
    """Create and return a ``torch.distributed`` group for the given ranks."""
    group = tdist.new_group(ranks=ranks)
    return group
def barrier():
    """Call ``torch.distributed.barrier`` on the default process group."""
    tdist.barrier()
def allreduce(t: torch.Tensor) -> None:
    """All-reduce ``t`` in place across the default process group.

    A tensor that is not on CUDA is copied to CUDA for the collective and the
    reduced values are copied back into ``t``.
    """
    if t.is_cuda:
        tdist.all_reduce(t)
        return

    staged = t.detach().cuda()
    tdist.all_reduce(staged)
    t.copy_(staged.cpu())
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    """Gather ``t`` from every rank.

    A tensor that is not on CUDA is moved to CUDA first; the results stay on
    CUDA.

    Args:
        t: the local tensor; one same-shaped receive buffer is allocated per
            rank.
        cat: when true, concatenate the gathered tensors along dim 0.

    Returns:
        One tensor (``cat=True``) or a list with one tensor per rank.
    """
    if not t.is_cuda:
        t = t.cuda()
    ls = [torch.empty_like(t) for _ in range(__world_size)]
    tdist.all_gather(ls, t)
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls
def broadcast(t: torch.Tensor, src_rank) -> None:
    """Broadcast ``t`` from ``src_rank`` to all ranks, in place.

    A tensor that is not on CUDA is copied to CUDA for the collective and the
    received values are copied back into ``t``.
    """
    if t.is_cuda:
        tdist.broadcast(t, src=src_rank)
        return

    staged = t.detach().cuda()
    tdist.broadcast(staged, src=src_rank)
    t.copy_(staged.cpu())
def dist_fmt_vals(val, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    """Collect one scalar per rank.

    Each rank writes ``val`` into its own slot of a zero vector of length
    world-size; an all-reduce then fills in the other ranks' slots.

    Args:
        val: this rank's value.
        fmt: ``%``-style format applied to each entry, or ``None`` to get the
            raw tensor back.

    Returns:
        The tensor of per-rank values, or a list of formatted strings.
    """
    per_rank = torch.zeros(__world_size)
    per_rank[__rank] = val
    allreduce(per_rank)

    if fmt is not None:
        return [fmt % v for v in per_rank.cpu().numpy().tolist()]
    return per_rank
def master_only(func):
    """Decorator: run ``func`` only where ``is_master()`` is true.

    Passing ``force=True`` to the wrapped call runs ``func`` on every rank;
    the ``force`` keyword is removed before ``func`` is invoked. Ranks that
    skip the call get ``None``. Every rank calls ``barrier()`` afterwards.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        should_run = kwargs.pop('force', False) or is_master()
        result = func(*args, **kwargs) if should_run else None
        barrier()
        return result
    return wrapper
def local_master_only(func):
    """Decorator: run ``func`` only where ``is_local_master()`` is true.

    Passing ``force=True`` to the wrapped call runs ``func`` on every rank;
    the ``force`` keyword is removed before ``func`` is invoked. Ranks that
    skip the call get ``None``. Every rank calls ``barrier()`` afterwards.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        should_run = kwargs.pop('force', False) or is_local_master()
        result = func(*args, **kwargs) if should_run else None
        barrier()
        return result
    return wrapper
@ -0,0 +1,207 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import torch |
import torch.nn as nn |
from timm.models.layers import DropPath |
# Current mask of active positions; annotated "B1ff" in the original, and the
# code below treats it as 4-D with a size-1 dim 1. It is not assigned in this
# part of the file.
_cur_active: torch.Tensor = None     # B1ff


def _get_active_ex_or_ii(H, returning_active_ex=True):
    """Expand ``_cur_active`` to spatial size ``H`` by repeating its entries.

    Args:
        H: target size; each entry is repeated ``H // _cur_active.shape[-1]``
            times along dims 2 and 3.
        returning_active_ex: when true return the expanded mask itself,
            otherwise the tuple of index tensors (bi, hi, wi) of its non-zero
            entries after dim 1 is squeezed.
    """
    repeat = H // _cur_active.shape[-1]
    expanded = _cur_active.repeat_interleave(repeat, 2).repeat_interleave(repeat, 3)
    if returning_active_ex:
        return expanded
    return expanded.squeeze(1).nonzero(as_tuple=True)  # ii: bi, hi, wi
def sp_conv_forward(self, x: torch.Tensor):
    """Run the parent layer's ``forward`` and zero the inactive positions.

    Meant to be assigned as ``forward`` on subclasses of dense layers.
    NOTE(review): ``super(type(self), self)`` resolves relative to the runtime
    class, so further subclassing a class that uses this function would recurse.
    """
    out = super(type(self), self).forward(x)
    active_mask = _get_active_ex_or_ii(H=out.shape[2], returning_active_ex=True)
    out *= active_mask    # (BCHW) *= (B1HW)
    return out
def sp_bn_forward(self, x: torch.Tensor):
    """Apply the parent normalization to the active positions only.

    ``x`` is permuted from BCHW to BHWC, the active positions are gathered
    into a 2-D (N, C) batch for the parent ``forward``, and the result is
    scattered into a zero tensor that is permuted back to BCHW.
    """
    active_idx = _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=False)

    channels_last = x.permute(0, 2, 3, 1)
    normed = super(type(self), self).forward(channels_last[active_idx])  # BN1d forward

    out = torch.zeros_like(channels_last)
    out[active_idx] = normed
    return out.permute(0, 3, 1, 2)
class SparseConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose forward pass is ``sp_conv_forward``."""
    forward = sp_conv_forward


class SparseMaxPooling(nn.MaxPool2d):
    """``nn.MaxPool2d`` whose forward pass is ``sp_conv_forward``."""
    forward = sp_conv_forward


class SparseAvgPooling(nn.AvgPool2d):
    """``nn.AvgPool2d`` whose forward pass is ``sp_conv_forward``."""
    forward = sp_conv_forward


class SparseBatchNorm2d(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` applied to 4-D inputs through ``sp_bn_forward``."""
    forward = sp_bn_forward


class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    """``nn.SyncBatchNorm`` whose forward pass is ``sp_bn_forward``."""
    forward = sp_bn_forward
class SparseConvNeXtLayerNorm(nn.LayerNorm):
    r""" LayerNorm supporting two data formats and an optional sparse mode.

    ``channels_last`` (default) expects 4-D inputs shaped
    (batch_size, height, width, channels); ``channels_first`` expects
    (batch_size, channels, height, width). In sparse mode only the active
    positions are normalized and all other positions of the output are zero.
    Inputs that are not 4-D are supported in dense mode only.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        super().__init__(normalized_shape, eps, elementwise_affine=True)
        self.data_format = data_format
        self.sparse = sparse

    def forward(self, x):
        if x.ndim != 4:  # BLC or BC
            if self.sparse:
                raise NotImplementedError
            return super().forward(x)

        if self.data_format == "channels_last":  # BHWC
            if not self.sparse:
                return super().forward(x)
            ii = _get_active_ex_or_ii(H=x.shape[1], returning_active_ex=False)
            normed = super().forward(x[ii])
            out = torch.zeros_like(x)
            out[ii] = normed
            return out

        # channels_first
        if self.sparse:
            ii = _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=False)
            bhwc = x.permute(0, 2, 3, 1)
            normed = super().forward(bhwc[ii])
            out = torch.zeros_like(bhwc)
            out[ii] = normed
            return out.permute(0, 3, 1, 2)

        # Dense channels_first: normalize over dim 1 by hand.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x

    def __repr__(self):
        base = super().__repr__()
        return base[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'
class SparseConvNeXtBlock(nn.Module):
    r""" ConvNeXt block with an optional sparse mode.

    Pipeline: depthwise conv -> permute to (N, H, W, C) -> LayerNorm
    (channels_last) -> Linear -> GELU -> Linear -> optional layer scale ->
    permute back -> optional masking -> residual add with drop path.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; values
            <= 0 disable the layer scale. Default: 1e-6.
        sparse (bool): Normalize and keep only the active positions.
        ks (int): Kernel size of the depthwise conv. Default: 7
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim)  # depthwise conv
        self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma = None
        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sparse = sparse

    def forward(self, x):
        shortcut = x
        out = self.dwconv(x)
        out = out.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        out = self.pwconv2(self.act(self.pwconv1(self.norm(out))))
        if self.gamma is not None:
            out = self.gamma * out
        out = out.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        if self.sparse:
            out *= _get_active_ex_or_ii(H=out.shape[2], returning_active_ex=True)

        return shortcut + self.drop_path(out)

    def __repr__(self):
        base = super().__repr__()
        return base[:-1] + f', sp={self.sparse})'
class SparseEncoder(nn.Module):
    """Wrap a dense CNN after converting its layers to their sparse versions.

    Args:
        cnn: the dense model; it is converted by ``dense_model_to_sparse``.
        input_size: stored as ``self.input_size``.
        downsample_raito: stored as ``self.downsample_raito``.
        encoder_fea_dim: stored as ``self.fea_dim``.
        verbose: forwarded to ``dense_model_to_sparse``.
        sbn: use the sync-batch-norm variant when converting batch norms.
    """

    def __init__(self, cnn, input_size, downsample_raito, encoder_fea_dim, verbose=False, sbn=False):
        super(SparseEncoder, self).__init__()
        self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
        self.input_size, self.downsample_raito, self.fea_dim = input_size, downsample_raito, encoder_fea_dim

    @staticmethod
    def _copy_tensor(dst, src):
        """Copy ``src`` into ``dst`` in place; do nothing if either is ``None``."""
        if dst is not None and src is not None:
            dst.data.copy_(src.data)

    @staticmethod
    def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
        """Recursively replace dense layers by sparse ones, keeping their state.

        Conv2d, MaxPool2d, AvgPool2d, BatchNorm2d/SyncBatchNorm and plain
        LayerNorm modules are rebuilt as their ``Sparse*`` counterparts with
        parameters and buffers copied over; any other module is kept as is.
        Children are converted recursively and re-attached under the same
        names.

        Raises:
            NotImplementedError: if a ``nn.Conv1d`` is encountered.
        """
        copy_tensor = SparseEncoder._copy_tensor
        oup = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels, m.out_channels,
                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode,
            )
            copy_tensor(oup.weight, m.weight)
            if bias:
                copy_tensor(oup.bias, m.bias)
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, return_indices=m.return_indices, ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode, count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            # num_features equals weight.shape[0] but also exists when affine=False.
            oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine, track_running_stats=m.track_running_stats)
            copy_tensor(oup.weight, m.weight)
            copy_tensor(oup.bias, m.bias)
            copy_tensor(oup.running_mean, m.running_mean)
            copy_tensor(oup.running_var, m.running_var)
            copy_tensor(oup.num_batches_tracked, m.num_batches_tracked)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
        elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
            m: nn.LayerNorm
            oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps)
            copy_tensor(oup.weight, m.weight)
            copy_tensor(oup.bias, m.bias)
        elif isinstance(m, (nn.Conv1d,)):
            raise NotImplementedError

        for name, child in m.named_children():
            oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
        return oup

    def forward(self, x, pyramid):
        """Run the converted CNN, forwarding ``pyramid`` as a keyword argument."""
        return self.sp_cnn(x, pyramid=pyramid)
@ -0,0 +1,87 @@ |
#!/usr/bin/python3 |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import argparse |
import functools |
import os |
import socket |
import subprocess |
import sys |
from typing import List |
def echo(info):
    """Print ``info`` prefixed with a timestamp and the caller's file and line.

    The message is written with ``print`` rather than through a shell
    ``echo`` so that quotes, ``$`` or backticks inside ``info`` are printed
    literally and never interpreted. Returns 0, the value the shell-based
    call returned on success.
    """
    import time

    caller = sys._getframe().f_back
    stamp = time.strftime('%m-%d-%H:%M:%S')
    print(f'[{stamp}] ({os.path.basename(caller.f_code.co_filename)}, line{caller.f_lineno})=> {info}', flush=True)
    return 0


# Run a command line through the shell and return its exit code.
os_system = functools.partial(subprocess.call, shell=True)


def os_system_get_stdout(cmd):
    """Run ``cmd`` through the shell and return its stdout decoded as UTF-8."""
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
    """Run ``cmd`` through the shell and capture both output streams.

    Returns:
        A ``(stdout, stderr)`` tuple, each decoded as UTF-8.
    """
    sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
def __find_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    The probing socket is closed before returning (also when ``bind`` fails).
    NOTE: the port can still be taken by another process between this call
    and the moment the caller binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        return sock.getsockname()[1]
if __name__ == '__main__':
    # Command-line launcher: builds a `python3 -m torch.distributed.launch`
    # command from the arguments, runs it through the shell and exits with its
    # exit code.
    parser = argparse.ArgumentParser(description='PyTorch Distributed Launcher')
    parser.add_argument('--main_py_relpath', type=str, default='',
                        help='specify launcher script.')

    # distributed environment
    parser.add_argument('--num_nodes', type=int, default=1)
    parser.add_argument('--ngpu_per_node', type=int, default=1)
    parser.add_argument('--node_rank', type=int, default=0,
                        help='node rank, ranged from 0 to [dist_num_nodes]-1')
    parser.add_argument('--master_address', type=str, default='',
                        help='master address for distributed communication')
    parser.add_argument('--master_port', type=int, default=30001,
                        help='master port for distributed communication')

    # other args
    known_args, other_args = parser.parse_known_args()
    other_args: List[str]
    echo(f'[other_args received by]: {other_args}')

    # The last unknown argument is treated as a space-separated list of
    # `key=value` items: whitespace around '=' is stripped and a bare `key`
    # becomes `key=1`. Skipped when no extra argument was given (this used to
    # raise IndexError).
    if other_args:
        main_args = other_args[-1]
        main_args = '='.join(map(str.strip, main_args.split('=')))
        main_args = main_args.split(' ')
        for i, a in enumerate(main_args):
            if len(a) and '=' not in a:
                main_args[i] = f'{a}=1'
        other_args[-1] = ' '.join(main_args)
        echo(f'[final other_args]: {other_args[-1]}')

    launch_opts = [f'--nproc_per_node={known_args.ngpu_per_node}']
    if known_args.num_nodes > 1:
        os.environ['NPROC_PER_NODE'] = str(known_args.ngpu_per_node)
        launch_opts += [
            f'--nnodes={known_args.num_nodes}',
            f'--node_rank={known_args.node_rank}',
            f'--master_addr={known_args.master_address}',
        ]
    launch_opts.append(f'--master_port={known_args.master_port}')

    cmd = ' '.join([
        'python3 -m torch.distributed.launch',
        *launch_opts,
        known_args.main_py_relpath,
        ' '.join(other_args),
    ])

    exit_code = subprocess.call(cmd, shell=True)
    sys.exit(exit_code)
@ -0,0 +1,156 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import math
import sys
import time
from functools import partial
from typing import List

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import meta, misc, optim
from utils.imagenet import build_imagenet
from utils.lr_control import lr_wd_annealing, get_param_groups
# Entry point of pre-training: initializes the distributed state and the args,
# builds the training data loader, the sparse encoder, the decoder and the
# SparK model, wraps the model in DistributedDataParallel, builds the
# optimizer, optionally resumes from `args.resume`, then runs the epoch loop
# and writes `ckpt-last.pth` after every epoch and `ckpt-final.pth` at the end.
def main_pt(): |
args: meta.Args = meta.init_dist_and_get_args() |
print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}') |
print(f'initial args:\n{str(args)}') |
args.log_epoch() |
# build data |
print(f'[build data for pre-training] ...\n') |
dataset_train, _ = build_imagenet('pt', args.data_path, args.data_set, args.input_size, eval_crop_pct=None, rrc=args.rrc) |
data_loader_train = DataLoader( |
dataset=dataset_train, num_workers=args.num_workers, pin_memory=True, |
batch_sampler=DistInfiniteBatchSampler( |
dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed, |
shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(), |
), worker_init_fn=worker_init_fn |
) |
# A single iterator is created once and reused by every epoch.
itrt_train = iter(data_loader_train) |
iters_train = len(data_loader_train) |
print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}') |
# build models (encoder, decoder, and other components) |
enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False) |
dec = LightDecoder(args.dec_dim, enc.downsample_raito, double=args.double, heavy=args.hea, cmid=args.cmid, sbn=args.sbn) |
spark = SparK( |
sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask, mask_ratio2=args.mask2, uniform=args.uni, |
# NOTE(review): the next two lines are garbled in this copy: two keyword
# arguments are missing before `dense_loss` and one after `sbn`. Restore them
# from the SparK constructor signature.
||||,, dense_loss=args.den, loss_l2=args.loss_l2, |
en_de_norm=args.en_de_norm, en_de_lin=args.en_de_lin, sbn=args.sbn,, |
) |
print(f'[PT model] model = {spark}\n') |
# NOTE(review): the statement on the next line was lost; presumably it moved
# `spark` to the training device before the DDP wrap — TODO confirm.
|||| |
model: DistributedDataParallel = DistributedDataParallel(spark, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False) |
model_without_ddp: SparK = model.module |
# build optimizer and lr_scheduler |
param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'pos_embed', 'mask_token', 'gamma'}, lr_scale=0) |
opt_clz = { |
'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True), |
'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)), |
'lamb': partial(optim.TimmLAMB, betas=(0.9, args.ada), max_grad_norm=args.clip), |
}[args.opt] |
# NOTE(review): an argument is missing between `params` and `weight_decay`
# below; presumably the learning rate — TODO confirm.
optimizer = opt_clz(params=param_groups,, weight_decay=0.0) |
print(f'[optimizer] optimizer({opt_clz}) ={optimizer}\n') |
# try to resume |
# Without a resume path, training starts at epoch 0 with a placeholder description.
next_ep, performance_desc = misc.load_checkpoint(args.resume, model_without_ddp, optimizer) if len(args.resume) else (0, '[no performance_desc]') |
if next_ep >= args.ep: |
# load from a complete checkpoint file |
print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}') |
else: |
# perform pre-training |
start_time = time.time() |
min_loss = 1e9 |
print(f'[PT start] from ep{next_ep}') |
for ep in range(next_ep, args.ep): |
if hasattr(itrt_train, 'set_epoch'): |
itrt_train.set_epoch(ep) |
stats, (sec, remain_time, finish_time) = pre_train_one_ep(ep, args, itrt_train, iters_train, model, optimizer) |
last_loss = stats['last_loss'] |
min_loss = min(min_loss, last_loss) |
performance_desc = f'{min_loss:.4f} {last_loss:.4f}' |
print(f' [*] [ep{ep}] Min/Last Recon Loss: {performance_desc}, Remain: {remain_time}, Finish: {finish_time}') |
# Record the progress on `args` so that `args.log_epoch()` can report it.
args.cur_phase = 'PT' |
args.cur_ep = f'{ep+1}/{args.ep}' |
args.remain_time, args.finish_time = str(remain_time), str(finish_time) |
args.last_loss = last_loss |
args.log_epoch() |
misc.save_checkpoint(f'ckpt-last.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict()) |
# finish pre-training |
print('\n\n') |
print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - start_time) / 60 / 60:.1f}h') |
print('\n\n') |
misc.save_checkpoint(f'ckpt-final.pth', args, args.ep-1, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict()) |
args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time())) |
args.log_epoch() |
# Run one pre-training epoch of `iters_train` iterations drawn from
# `itrt_train`: adjust lr/wd, forward, backward, optional gradient clipping,
# optimizer step and metric logging.
# Returns a dict mapping each meter name to its `global_avg`, together with the
# result of `me.iter_time.time_preds(...)` for the remaining epochs.
def pre_train_one_ep(ep, args, itrt_train, iters_train, model: DistributedDataParallel, optimizer): |
model.train() |
me = misc.MetricLogger(delimiter=' ') |
me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}')) |
header = f'[PT] Epoch: [{ep:3d}/{args.ep}]' |
optimizer.zero_grad() |
# Clipping is done here ("early") unless the optimizer exposes
# `global_grad_norm`, in which case the norm is only read back after the step.
early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm') |
late_clipping = args.clip > 0 and hasattr(optimizer, 'global_grad_norm') |
if early_clipping: |
params_req_grad = [p for p in model.parameters() if p.requires_grad] |
# for every batch do: |
for it, (inp, _) in enumerate(me.log_every(iters_train, itrt_train, 3, header)): |
# adjust lr and wd |
g_it = it + ep*iters_train |
# NOTE(review): an argument is missing after `optimizer` in the call below;
# presumably the peak learning rate — TODO confirm.
min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer,, args.wd, args.wde, g_it, args.wp_ep*iters_train, args.ep*iters_train) |
# forward and backward |
# NOTE(review): the right-hand side of the next assignment is truncated;
# presumably it moved `inp` to the training device — TODO confirm.
inp =, non_blocking=True) |
# The bare attribute expression below has no runtime effect.
SparK.forward |
active_ex, rec, loss = model(inp) |
optimizer.zero_grad() |
loss.backward() |
loss = loss.item() |
if not math.isfinite(loss): |
# NOTE(review): the builtin `print` has no `force` keyword; presumably `print`
# is replaced elsewhere in the project — TODO confirm.
print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True) |
sys.exit(-1) |
# optimize |
grad_norm = None |
if early_clipping: grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip) |
optimizer.step() |
if late_clipping: grad_norm = optimizer.global_grad_norm |
torch.cuda.synchronize() |
# log |
me.update(last_loss=loss) |
me.update(max_lr=max_lr) |
if grad_norm is not None: |
me.update(orig_norm=grad_norm) |
me.synchronize_between_processes() |
return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds((args.ep-1-ep) * (iters_train+10)) |
# Run pre-training when this file is executed as a script.
if __name__ == '__main__': |
main_pt() |
@ -0,0 +1,83 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import torch |
from timm import create_model |
from timm.loss import SoftTargetCrossEntropy |
from timm.models.layers import drop |
from models.convnext import ConvNeXt |
from models.resnet import ResNet |
# Holds a reference to the imported ResNet class so the import above is used;
# presumably importing these model modules is what registers them with timm —
# TODO confirm.
_import_resnets_for_timm_registration = (ResNet,)
# log more
def _ex_repr(self):
    """Return ``key=value`` pairs for the plain attributes of ``self``.

    Skipped: names starting with ``_``, the ``training`` flag, and values that
    are modules or tensors. Floats use the ``g`` format, all other values
    ``str()``.
    """
    parts = []
    for key, value in vars(self).items():
        if key.startswith('_') or key == 'training':
            continue
        if isinstance(value, (torch.nn.Module, torch.Tensor)):
            continue
        rendered = f'{value:g}' if isinstance(value, float) else str(value)
        parts.append(f'{key}={rendered}')
    return ', '.join(parts)
# Make the printed form of these classes show their plain attributes: use
# `_ex_repr` as `extra_repr` where that hook exists, otherwise install a
# `__repr__` built on it.
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
    if not hasattr(clz, 'extra_repr'):
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
    else:
        clz.extra_repr = _ex_repr
# Short model aliases and the full model names they stand for.
model_alias_to_fullname = dict(
    res50='resnet50',
    res101='resnet101',
    res152='resnet152',
    res200='resnet200',
    cnxS='convnext_small',
    cnxB='convnext_base',
    cnxL='convnext_large',
)
# Reverse lookup: full model name -> alias.
model_fullname_to_alias = {fullname: alias for alias, fullname in model_alias_to_fullname.items()}


pre_train_d = {  # default drop_path_rate, num of para, FLOPs, downsample_ratio, num of channel
    'resnet50': [{'drop_path_rate': 0.05}, 25.6, 4.1, 32, 2048],
    'resnet101': [{'drop_path_rate': 0.08}, 44.5, 7.9, 32, 2048],
    'resnet152': [{'drop_path_rate': 0.10}, 60.2, 11.6, 32, 2048],
    'resnet200': [{'drop_path_rate': 0.15}, 64.7, 15.1, 32, 2048],
    'convnext_small': [{'sparse': True, 'drop_path_rate': 0.2}, 50.0, 8.7, 32, 768],
    'convnext_base': [{'sparse': True, 'drop_path_rate': 0.3}, 89.0, 15.4, 32, 1024],
    'convnext_large': [{'sparse': True, 'drop_path_rate': 0.4}, 198.0, 34.4, 32, 1536],
}
# Every entry shares these model-construction kwargs.
for v in pre_train_d.values():
    v[0].update(pretrained=False, num_classes=0, global_pool='')
def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
    """Create the timm model ``name`` and wrap it in a ``SparseEncoder``.

    Args:
        name: a key of ``pre_train_d``.
        input_size: input resolution; stored on the encoder and used for the
            probing forward pass described below.
        sbn: forwarded to ``SparseEncoder``.
        drop_path_rate: overrides the table's default when non-zero.
        verbose: forwarded to ``SparseEncoder``.

    Returns:
        The ``SparseEncoder`` built around the created model.
    """
    from encoder import SparseEncoder

    default_kwargs, _params, _flops, downsample_raito, fea_dim = pre_train_d[name]
    # Copy before editing: the dict in `pre_train_d` is shared module state and
    # an override must not leak into later calls.
    kwargs = dict(default_kwargs)
    if drop_path_rate != 0:
        kwargs['drop_path_rate'] = drop_path_rate
    print(f'[sparse_cnn] model kwargs={kwargs}')
    cnn = create_model(name, **kwargs)

    # Disable the model's global pooling, whichever way it is represented.
    if hasattr(cnn, 'global_pool'):
        if callable(cnn.global_pool):
            cnn.global_pool = torch.nn.Identity()
        elif isinstance(cnn.global_pool, str):
            cnn.global_pool = ''

    # If the table does not provide integer values, measure them with a dummy
    # forward pass.
    if not isinstance(downsample_raito, int) or not isinstance(fea_dim, int):
        with torch.no_grad():
            cnn.eval()
            o = cnn(torch.rand(1, 3, input_size, input_size))
            downsample_raito = input_size // o.shape[-1]
            fea_dim = o.shape[1]
            cnn.train()
    print(f'[sparse_cnn] downsample_raito={downsample_raito}, fea_dim={fea_dim}')

    return SparseEncoder(cnn, input_size=input_size, downsample_raito=downsample_raito, encoder_fea_dim=fea_dim, verbose=verbose, sbn=sbn)
@ -0,0 +1,212 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
# |
# This file is basically a copy of an upstream ConvNeXt implementation (the source link that followed this line is missing).
import torch |
import torch.nn as nn |
from timm.models.layers import trunc_normal_ |
from timm.models.registry import register_model |
from encoder import SparseConvNeXtBlock, SparseConvNeXtLayerNorm |
class ConvNeXt(nn.Module):
    r""" ConvNeXt
    A PyTorch implementation of `A ConvNet for the 2020s`.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head; a value
            <= 0 replaces the final norm and the head by ``nn.Identity``.
            Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3)
        dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Accepted for signature compatibility; this
            constructor does not use it. Default: 1.
        global_pool (str): An empty string disables the global average pooling
            applied when ``pyramid`` is 0. Default: 'avg'
        sparse (bool): Forwarded to the backbone's norm layers and blocks.
    """

    def __init__(self, in_chans=3, num_classes=1000,
                 depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1., global_pool='avg',
                 sparse=True,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            SparseConvNeXtLayerNorm(dims[0], eps=1e-6, data_format="channels_first", sparse=sparse)
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                SparseConvNeXtLayerNorm(dims[i], eps=1e-6, data_format="channels_first", sparse=sparse),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value
        # Drop-path rates grow linearly from 0 to `drop_path_rate` over all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[SparseConvNeXtBlock(dim=dims[i], drop_path=dp_rates[cur + j],
                                      layer_scale_init_value=layer_scale_init_value, sparse=sparse) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]
        self.depths = depths

        # Applied before the head exists, so `norm` and `fc` keep PyTorch's
        # default initialization.
        self.apply(self._init_weights)
        if num_classes > 0:
            self.norm = SparseConvNeXtLayerNorm(dims[-1], eps=1e-6, sparse=False)  # final norm layer for LE/FT; should not be sparse
            self.fc = nn.Linear(dims[-1], num_classes)
        else:
            self.norm = nn.Identity()
            self.fc = nn.Identity()

        self.with_pooling = len(global_pool) > 0

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights and zero biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:  # layers built with bias=False have none
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x, pyramid: int):  # pyramid: 0, 1, 2, 3, 4
        """Run the four stages.

        With ``pyramid == 0`` return the last feature map (globally averaged
        over H and W when pooling is enabled). Otherwise return a list of
        length 4 holding the last ``pyramid`` stage outputs, left-padded with
        ``None``.
        """
        ls = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if pyramid:
                ls.append(x)

        if pyramid:
            # Keep only the outputs of the last `pyramid` stages.
            ls = ls[max(0, len(ls) - pyramid):]
            return [None] * (4 - pyramid) + ls

        if self.with_pooling:
            x = x.mean([-2, -1])  # global average pooling, (N, C, H, W) -> (N, C)
        return x

    def forward(self, x, pyramid=0):
        """Classify ``x`` (``pyramid == 0``) or return the feature pyramid."""
        if pyramid == 0:
            x = self.forward_features(x, pyramid=pyramid)
            return self.fc(self.norm(x))
        return self.forward_features(x, pyramid=pyramid)

    def get_classifier(self):
        return self.fc

    def extra_repr(self):
        return f'drop_path_rate={self.drop_path_rate}, layer_scale_init_value={self.layer_scale_init_value:g}'

    def get_layer_id_and_scale_exp(self, para_name: str):
        """Map a parameter name to ``(layer_id, N + 1 - layer_id)``.

        ``N`` is 12 when the third stage has more than 9 blocks, otherwise 6.
        Names outside ``downsample_layers`` / ``stages`` get ``layer_id = N + 1``.
        """
        N = 12 if self.depths[-2] > 9 else 6
        if para_name.startswith("downsample_layers"):
            stage_id = int(para_name.split('.')[1])
            if stage_id == 0:
                layer_id = 0
            elif stage_id == 1 or stage_id == 2:
                layer_id = stage_id + 1
            else:  # stage_id == 3:
                layer_id = N
        elif para_name.startswith("stages"):
            stage_id = int(para_name.split('.')[1])
            block_id = int(para_name.split('.')[2])
            if stage_id == 0 or stage_id == 1:
                layer_id = stage_id + 1
            elif stage_id == 2:
                layer_id = 3 + block_id // 3
            else:  # stage_id == 3:
                layer_id = N
        else:
            layer_id = N + 1  # after backbone

        return layer_id, N + 1 - layer_id
# Checkpoint URL per ConvNeXt variant.
# NOTE(review): every URL is an empty string in this copy, so the
# `pretrained=True` paths below have nothing to download — restore the links.
model_urls = dict(
    convnext_tiny_1k="",
    convnext_small_1k="",
    convnext_base_1k="",
    convnext_large_1k="",
    convnext_tiny_22k="",
    convnext_small_22k="",
    convnext_base_22k="",
    convnext_large_22k="",
    convnext_xlarge_22k="",
)
@register_model
def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    """Build ConvNeXt with depths (3, 3, 9, 3) and dims (96, 192, 384, 768).

    With ``pretrained`` the state dict is loaded (hash-checked) from the
    ``convnext_tiny_22k`` URL if ``in_22k`` else ``convnext_tiny_1k``.
    """
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if not pretrained:
        return model
    url_key = 'convnext_tiny_22k' if in_22k else 'convnext_tiny_1k'
    checkpoint = torch.hub.load_state_dict_from_url(url=model_urls[url_key], map_location="cpu", check_hash=True)
    model.load_state_dict(checkpoint["model"])
    return model
@register_model
def convnext_small(pretrained=False, in_22k=False, **kwargs):
    """Build ConvNeXt with depths (3, 3, 27, 3) and dims (96, 192, 384, 768).

    With ``pretrained`` the state dict is loaded from the
    ``convnext_small_22k`` URL if ``in_22k`` else ``convnext_small_1k``.
    """
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if not pretrained:
        return model
    url_key = 'convnext_small_22k' if in_22k else 'convnext_small_1k'
    checkpoint = torch.hub.load_state_dict_from_url(url=model_urls[url_key], map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    return model
@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
    """Build ConvNeXt with depths (3, 3, 27, 3) and dims (128, 256, 512, 1024).

    With ``pretrained`` the state dict is loaded from the
    ``convnext_base_22k`` URL if ``in_22k`` else ``convnext_base_1k``.
    """
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if not pretrained:
        return model
    url_key = 'convnext_base_22k' if in_22k else 'convnext_base_1k'
    checkpoint = torch.hub.load_state_dict_from_url(url=model_urls[url_key], map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    return model
@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
    """Build the large ConvNeXt variant (depths 3-3-27-3, dims 192-384-768-1536).

    Args:
        pretrained: when true, download a checkpoint and load its "model" entry.
        in_22k: selects the `convnext_large_22k` URL instead of `convnext_large_1k`.
        **kwargs: forwarded unchanged to the `ConvNeXt` constructor.

    Returns:
        The constructed (and optionally weight-loaded) model.
    """
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if not pretrained:
        return model

    url_key = 'convnext_large_22k' if in_22k else 'convnext_large_1k'
    state = torch.hub.load_state_dict_from_url(url=model_urls[url_key], map_location="cpu")
    model.load_state_dict(state["model"])
    return model
@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    """Build the xlarge ConvNeXt variant (depths 3-3-27-3, dims 256-512-1024-2048).

    Args:
        pretrained: when true, download the `convnext_xlarge_22k` checkpoint
            and load its "model" entry.
        in_22k: must be true whenever `pretrained` is true (asserted).
        **kwargs: forwarded unchanged to the `ConvNeXt` constructor.

    Returns:
        The constructed (and optionally weight-loaded) model.
    """
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if not pretrained:
        return model

    # Only a 22K entry exists in `model_urls` for this variant.
    assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
    state = torch.hub.load_state_dict_from_url(url=model_urls['convnext_xlarge_22k'], map_location="cpu")
    model.load_state_dict(state["model"])
    return model
if __name__ == '__main__':
    # Manual smoke test: print the plain forward output shape, then the shapes
    # returned for every pyramid setting from 1 to 4.
    from timm.models import create_model

    c = create_model('convnext_small', sparse=False)
    with torch.no_grad():
        x = torch.rand(2, 3, 224, 224)
        print(c(x).shape)
        for n_scales in (1, 2, 3, 4):
            print([None if f is None else f.shape for f in c(x, pyramid=n_scales)])
@ -0,0 +1,105 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import math |
import torch.nn.functional as F |
from timm.models.resnet import ResNet |
def forward_features(self, x, pyramid: int):  # pyramid: 0, 1, 2, 3, 4
    """Run the stem and the four residual stages, optionally keeping stage outputs.

    Args:
        x: input passed through `conv1 -> bn1 -> act1 -> maxpool`.
        pyramid: 0 returns only the output of `layer4`; a positive value keeps
            the outputs of the last `pyramid` stages.

    Returns:
        The `layer4` output when `pyramid` is 0; otherwise a list made of
        `4 - pyramid` leading `None` entries followed by the kept stage outputs
        (in `layer1 -> layer4` order).
    """
    x = self.maxpool(self.act1(self.bn1(self.conv1(x))))

    stage_outputs = []
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
        x = stage(x)
        if pyramid:
            stage_outputs.append(x)

    if not pyramid:
        return x

    # Drop the earliest stages, from the back of the to-be-removed prefix to its front.
    for idx in reversed(range(len(stage_outputs) - pyramid)):
        del stage_outputs[idx]
    return [None] * (4 - pyramid) + stage_outputs
def forward(self, x, pyramid=0):
    """Classify `x`, or return multi-scale features when `pyramid` is non-zero.

    Args:
        x: network input.
        pyramid: 0 runs the full classifier (features -> global pool ->
            optional dropout -> fc); any other value returns whatever
            `forward_features` returns for that `pyramid`.

    Returns:
        The `fc` output when `pyramid == 0`, else the result of
        `self.forward_features(x, pyramid=pyramid)`.
    """
    if pyramid != 0:
        return self.forward_features(x, pyramid=pyramid)

    x = self.forward_features(x, pyramid=pyramid)
    x = self.global_pool(x)
    if self.drop_rate:
        # Dropout must follow the module's train/eval mode; the call was
        # truncated (unclosed parenthesis) before this fix.
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    return self.fc(x)
def resnets_get_layer_id_and_scale_exp(self, para_name: str):
    """Map a parameter name to a layer index and its distance from the last index.

    Blocks of `layer2` / `layer3` are bucketed into groups whose size depends
    on the (len(layer2), len(layer3)) pair:

        50      : [3, 4, 6, 3]
        101     : [3, 4, 23, 3]
        152     : [3, 8, 36, 3]
        200     : [3, 24, 36, 3]
        eca269d : [3, 30, 48, 8]

    Args:
        para_name: dotted parameter name, e.g. ``layer3.5.conv1.weight``.

    Returns:
        ``(layer_id, N + 1 - layer_id)`` where ``N + 1`` is the id given to
        ``fc.*`` parameters and 0 is given to everything outside
        ``layer*`` / ``fc.``  (r50: ids 0..7, r101: ids 0..13).

    Raises:
        NotImplementedError: for an unlisted (len(layer2), len(layer3)) pair.
    """
    # (len(layer2), len(layer3)) -> (blocks per group in layer2, in layer3)
    group_sizes = {
        (4, 6): (2, 3),
        (4, 23): (2, 3),
        (8, 36): (4, 4),
        (24, 36): (4, 4),
        (30, 48): (5, 6),
    }
    depths = (len(self.layer2), len(self.layer3))
    if depths not in group_sizes:
        raise NotImplementedError
    L2, L3 = depths
    blk2, blk3 = group_sizes[depths]

    N2 = math.ceil(L2 / blk2 - 1e-5)
    N3 = math.ceil(L3 / blk3 - 1e-5)
    N = 2 + N2 + N3

    if para_name.startswith('layer'):
        parts = para_name.split('.')
        stage_id = int(parts[0][5:])    # digits after the "layer" prefix
        block_id = int(parts[1])
        if stage_id == 1:
            layer_id = 1
        elif stage_id == 2:
            layer_id = 2 + block_id // blk2             # r50: 2, 3
        elif stage_id == 3:
            layer_id = 2 + N2 + block_id // blk3        # r50: 4, 5    r101: 4, 5, ..., 11
        else:   # stage 4
            layer_id = N                                # r50: 6       r101: 12
    elif para_name.startswith('fc.'):
        layer_id = N + 1                                # r50: 7       r101: 13
    else:
        layer_id = 0

    return layer_id, N + 1 - layer_id
# Install the functions defined in this module on timm's ResNet class itself
# (not on an instance), replacing its `forward_features` / `forward` and adding
# `get_layer_id_and_scale_exp`. This happens as an import-time side effect.
ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp
ResNet.forward_features = forward_features
ResNet.forward = forward
if __name__ == '__main__':
    # Manual smoke test: print the classifier output shape, then the raw
    # return value for every pyramid setting from 1 to 4 (fresh random input each time).
    import torch
    from timm.models import create_model

    r = create_model('resnet50')
    with torch.no_grad():
        print(r(torch.rand(2, 3, 224, 224)).shape)
        for n_scales in (1, 2, 3, 4):
            print(r(torch.rand(2, 3, 224, 224), pyramid=n_scales))
@ -0,0 +1,167 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import math |
import random |
import numpy as np |
import torch |
from import Sampler |
import dist |
def worker_init_fn(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's current initial seed.

    Args:
        worker_id: unused; accepted so the function matches the one-argument
            call shape of a worker-initialisation hook.
    """
    # Fold the torch seed into 32 bits before handing it to the other RNGs.
    seed = torch.initial_seed() % (2 ** 32)
    for seed_rng in (np.random.seed, random.seed):
        seed_rng(seed)
class RASampler(Sampler):
    """Distributed sampler with repeated augmentation.

    Every dataset index is repeated `repetitions` times back-to-back, and the
    repeated list is dealt out round-robin to the `num_replicas` processes, so
    the copies of one sample are spread over different ranks. Each rank then
    yields only the first `num_selected_samples` indices of its share.

    The original docstring credits the DeiT repository as the source of this
    sampler (the reference links are missing from this copy).
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
        # Fall back to the process-group values when not given explicitly.
        self.num_replicas = dist.get_world_size() if num_replicas is None else num_replicas
        self.rank = dist.get_rank() if rank is None else rank
        self.dataset = dataset
        self.shuffle = shuffle
        self.seed = seed
        self.repetitions = repetitions
        self.epoch = 0

        n_items = len(self.dataset)
        # Per-rank share of the repeated index list (rounded up) ...
        self.num_samples = int(math.ceil(n_items * float(repetitions) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # ... and the number of indices a rank actually yields per epoch:
        # the dataset length rounded down to a multiple of 256, split over ranks.
        self.num_selected_samples = int(math.floor(n_items // 256 * 256 / self.num_replicas))

    def __iter__(self):
        n_items = len(self.dataset)
        if self.shuffle:
            # Deterministic per-epoch permutation.
            rng = torch.Generator()
            rng.manual_seed(self.seed + self.epoch)
            order = torch.randperm(n_items, generator=rng).tolist()
        else:
            order = list(range(n_items))

        # Repeat each index, then pad from the front so the list splits evenly.
        repeated = [idx for idx in order for _ in range(self.repetitions)]
        repeated += repeated[: (self.total_size - len(repeated))]
        assert len(repeated) == self.total_size

        # Round-robin share of this rank.
        share = repeated[self.rank: self.total_size: self.num_replicas]
        assert len(share) == self.num_samples

        return iter(share[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch that is mixed into the shuffling seed."""
        self.epoch = epoch
class InfiniteBatchSampler(Sampler):
    """Batch sampler that never stops: it yields tuples of indices epoch after epoch.

    `len()` reports the number of batches in one epoch, but iteration itself
    is endless; with `shuffle=True` a new permutation is drawn after every epoch.
    """

    def __init__(self, dataset_len, batch_size, seed=0, filling=False, shuffle=True, drop_last=False):
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        if drop_last:
            self.iters_per_ep = dataset_len // batch_size
        else:
            self.iters_per_ep = (dataset_len + batch_size - 1) // batch_size
        # Exclusive upper bound for batch start positions within one epoch.
        self.max_p = self.iters_per_ep * batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()

    def gener_indices(self):
        """Build the index tuple for the current epoch."""
        if self.shuffle:
            rng = torch.Generator()
            rng.manual_seed(self.epoch + self.seed)
            order = torch.randperm(self.dataset_len, generator=rng).numpy()
        else:
            order = torch.arange(self.dataset_len).numpy()

        n_pad = self.batch_size - (self.dataset_len % self.batch_size)
        if self.filling and n_pad != self.batch_size:
            # `padding` is a view into `order`, so it reflects the in-place
            # shuffle below (which uses NumPy's global RNG).
            padding = order[:n_pad]
            np.random.shuffle(order)
            order = np.concatenate((order, padding))

        # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
        # noinspection PyTypeChecker
        return tuple(order.tolist())

    def __iter__(self):
        self.epoch = 0
        while True:
            self.epoch += 1
            for start in range(0, self.max_p, self.batch_size):
                yield self.indices[start:start + self.batch_size]
            if self.shuffle:
                self.indices = self.gener_indices()

    def __len__(self):
        return self.iters_per_ep
class DistInfiniteBatchSampler(InfiniteBatchSampler):
    """Per-rank variant of `InfiniteBatchSampler`.

    Every rank builds the same global index sequence (same seed and epoch) and
    keeps only its own contiguous slice; iteration is inherited from the base
    class and uses the `max_p` set by `gener_indices`.
    """

    def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=0, repeated_aug=0, filling=False, shuffle=True):
        # The base-class __init__ is intentionally not called: all attributes
        # it would set are (re)defined here with global-batch semantics.
        assert glb_batch_size % world_size == 0
        self.world_size, self.rank = world_size, rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size      # per-rank batch size

        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.repeated_aug = repeated_aug
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()

    def gener_indices(self):
        """Return this rank's index tuple for the current epoch and update `max_p`."""
        # global_max_p % world_size must be 0 because glb_batch_size % world_size == 0
        global_max_p = self.iters_per_ep * self.glb_batch_size
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            global_indices = torch.randperm(self.dataset_len, generator=g)
            if self.repeated_aug > 1:
                # Keep ceil(len / repeated_aug) distinct samples and repeat each
                # one `repeated_aug` times in a row.
                n_distinct = (self.dataset_len + self.repeated_aug - 1) // self.repeated_aug
                global_indices = global_indices[:n_distinct].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
        else:
            global_indices = torch.arange(self.dataset_len)

        filling = global_max_p - global_indices.shape[0]
        if filling > 0 and self.filling:
            # Pad with the leading indices so the last global batch is full.
            # (This statement was truncated in the source; restored here.)
            global_indices = torch.cat((global_indices, global_indices[:filling]))
        global_indices = tuple(global_indices.numpy().tolist())

        # world_size + 1 integer cut points over the global sequence; rank r
        # owns [seps[r], seps[r + 1]). (dtype argument restored from truncation.)
        seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int).tolist()
        local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
        self.max_p = len(local_indices)
        return local_indices
if __name__ == '__main__':
    # Manual check: print how many indices each of the 16 ranks receives.
    W = 16
    per_rank = [DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices() for rk in range(W)]
    for rk, ind in enumerate(per_rank):
        print(rk, len(ind))
@ -0,0 +1,49 @@ |
#!/usr/bin/env bash

# An example of pre-training (note that `/path/to/imagenet` should contain directories named `train` and `val`):
# > cd /path/to/SparK
# > bash ./scripts/<this_script> experiment_name /path/to/imagenet --num_nodes=1 --ngpu_per_node=8 --node_rank=0 --master_address=<master_ip> --master_port=30000 --model=res50 --ep=400
# NOTE(review): the script file name and the --master_address value were missing
# from the usage line in this copy; placeholders are shown above.


####### template begins #######
SCRIPTS_DIR=$(cd "$(dirname "$0")" && pwd)
cd "${SCRIPTS_DIR}" || exit 1
cd ../ || exit 1
SPARK_DIR=$(pwd)

shopt -s expand_aliases
alias python=python3
alias to_scripts_dir='cd "${SCRIPTS_DIR}"'
alias to_spark_dir='cd "${SPARK_DIR}"'
alias print='echo "$(date +"[%m-%d %H:%M:%S]") (>"'
function mkd() {
  mkdir -p "$1" >/dev/null 2>&1
}
####### template ends #######


# Positional arguments, as described by the usage line above. They were never
# assigned before, so every "${EXP_NAME}" / "${DATA_PATH}" expanded to nothing.
if [ "$#" -lt 2 ]; then
  echo "usage: $0 <experiment_name> </path/to/imagenet> [other args...]" >&2
  exit 1
fi
EXP_NAME=$1
DATA_PATH=$2

# NOTE(review): the launcher script after `python` and the value of
# --main_py_relpath were missing in this copy, which made the command unusable.
# `launch.py` / `main.py` are presumed defaults -- confirm against the
# repository, or override them through the environment.
LAUNCHER_PY="${LAUNCHER_PY:-launch.py}"
MAIN_PY_RELPATH="${MAIN_PY_RELPATH:-main.py}"

EXP_DIR="${SPARK_DIR}/output_${EXP_NAME}"
mkd "${EXP_DIR}"


print "===================== Args ====================="
print "EXP_NAME: ${EXP_NAME}"
print "DATA_PATH: ${DATA_PATH}"
print "EXP_DIR: ${EXP_DIR}"
print "[other_args sent to]: ${*:3}"
print "================================================"
print ""


print "============== Pretraining starts =============="
to_spark_dir
# "${@:3}" forwards every extra argument as its own word; the previous
# "${*:3}" glued them all into one single argument.
python "${LAUNCHER_PY}" \
--main_py_relpath "${MAIN_PY_RELPATH}" \
--exp_name "${EXP_NAME}" \
--data_path "${DATA_PATH}" \
--exp_dir "${EXP_DIR}" \
"${@:3}"
print "============== Pretraining ends =============="
@ -0,0 +1,313 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import random |
from pprint import pformat |
from typing import List |
import numpy as np |
import torch |
import torch.nn as nn |
from timm.models.layers import trunc_normal_ |
import encoder |
from decoder import LightDecoder |
class SparK(nn.Module):
    """Masked-image-modelling wrapper around a sparse encoder and a dense decoder.

    `forward` masks the input on the encoder's feature-map grid, encodes the
    visible part, fills the masked positions of every returned feature map with
    a learnable mask token (plus an optional fixed sin-cos position embedding),
    decodes, and returns a reconstruction loss.
    """

    def __init__(
        self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
        mask_ratio=0.6, mask_ratio2=0.6, uniform=False,
        using_pe=True, pix_norm=0, dense_loss=False, loss_l2=True,
        en_de_norm='bn', en_de_lin=True, sbn=False, pyramid=1,  # 1 for single-scale pre-training; 4 for full-scale pre-training
    ):
        """
        Args:
            sparse_encoder: provides `input_size`, `downsample_raito`, `fea_dim`
                and is called as `sparse_encoder(x, pyramid=...)`.
            dense_decoder: provides `fea_dim` and is called with the list of
                per-scale feature maps.
            mask_ratio, mask_ratio2: bounds of the masking ratio; when they are
                equal a fixed number of positions is kept.
            uniform: when the two ratios differ, draw one ratio uniformly per
                call instead of using the three-way per-batch split.
            using_pe: add a fixed 2D sin-cos position embedding to mask tokens.
            pix_norm: 0 = none, 1 = per-patch normalisation of the target,
                2 = add the per-image channel mean to the reconstruction.
            dense_loss: average the loss over all patches instead of only the
                masked ones.
            loss_l2: squared error if true, absolute error otherwise.
            en_de_norm: 'bn', 'ln', or anything else for no normalisation
                between encoder and decoder.
            en_de_lin: stored and reported in the config only; the projection
                layers are created regardless of its value.
            sbn: use the synchronised batch-norm variant when `en_de_norm == 'bn'`.
            pyramid: number of feature scales passed from encoder to decoder.
        """
        super().__init__()
        # `IMAGENET_DEFAULT_MEAN` / `IMAGENET_DEFAULT_STD` are used below but
        # were never imported by this module (NameError on construction).
        from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

        input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
        self.downsample_raito = downsample_raito
        fmap_size = input_size // downsample_raito
        self.fmap_size = fmap_size

        if mask_ratio != mask_ratio2 and not uniform:   # with an extra active site
            # One position per sample is forced active in `mask`, so the ratios
            # are scaled up by 1 / (1 - 1/f^2) (capped at 1).
            k = 1 / fmap_size**2
            mask_ratio = min(1, mask_ratio / (1-k))
            mask_ratio2 = min(1, mask_ratio2 / (1-k))
        self.mask_ratio = (mask_ratio, mask_ratio2)
        self.ratios = torch.tensor([self.mask_ratio[0], self.mask_ratio[1], (self.mask_ratio[0] + self.mask_ratio[1]) / 2])
        self.uniform = uniform
        self.len_keep = round(fmap_size * fmap_size * (1-mask_ratio))
        self.pix_norm = int(pix_norm)

        self.sparse_encoder = sparse_encoder
        self.dense_decoder = dense_decoder

        self.sbn = sbn
        self.pyramid = pyramid
        en_de_norm = en_de_norm.lower()
        self.en_de_norm_str = en_de_norm
        self.en_de_lin_bool = en_de_lin
        self.en_de_norms = nn.ModuleList()
        self.en_de_lins = nn.ModuleList()
        self.using_pe = using_pe
        self.pos_embeds = nn.ParameterList()
        self.mask_tokens = nn.ParameterList()

        # One (norm, projection, mask token[, pos embed]) set per scale, from
        # the smallest feature map to the largest: channels halve and the
        # spatial size doubles at every step.
        fea, d_fea, fmap = self.sparse_encoder.fea_dim, self.dense_decoder.fea_dim, fmap_size
        for i in range(self.pyramid):
            if en_de_norm == 'bn':
                n = (encoder.SparseSyncBatchNorm2d if sbn else encoder.SparseBatchNorm2d)(fea)
            elif en_de_norm == 'ln':
                n = encoder.SparseConvNeXtLayerNorm(fea, data_format='channels_first', sparse=True)
            else:
                n = nn.Identity()
            self.en_de_norms.append(n)

            kernel_size = 1 if i <= 0 else 3
            l = nn.Conv2d(fea, d_fea, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True)
            print(f'[mid, py={self.pyramid}][edl {i}]: k={kernel_size}, #para = {sum(x.numel() for x in l.parameters())/1e6:.2f}')
            if i == 0 and fea == d_fea:
                # No projection needed when the widths already match at the first scale.
                l = nn.Identity()
            self.en_de_lins.append(l)

            if self.using_pe:
                p = torch.from_numpy(get_2d_sincos_pos_embed(fea, fmap)).float()
                p = p.reshape(1, fmap, fmap, fea).permute(0, 3, 1, 2).contiguous()
                p = nn.Parameter(p, requires_grad=False)    # fixed, not trained
                self.pos_embeds.append(p)

            p = nn.Parameter(torch.zeros(1, fea, 1, 1))
            trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
            self.mask_tokens.append(p)
            fea //= 2
            d_fea //= 2
            fmap *= 2

        print(f'[mid, py={self.pyramid}][mask_tokens]: {tuple(p.numel() for p in self.mask_tokens)}')

        self.loss_l2, self.dense_loss = loss_l2, dense_loss

        m = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, 3, 1, 1)
        s = torch.tensor(IMAGENET_DEFAULT_STD).view(1, 3, 1, 1)
        self.register_buffer('imn_m', m)
        self.register_buffer('imn_s', s)
        # self.register_buffer('norm_black', (torch.ones(1, 3, input_size, input_size) * 0.45 - m) / s)
        self.register_buffer('norm_black', torch.zeros(1, 3, input_size, input_size))
        self.vis_active = self.vis_active_ex = self.vis_inp = self.vis_inp_mask = ...

    def mask(self, shape, device, generator=None):
        """Sample a boolean keep-mask of shape (B, 1, f, f); True marks a visible position.

        Args:
            shape: (B, C, H, W) of the input batch; only B is used.
            device: device of the returned mask.
            generator: optional torch generator for the random draws.
        """
        B, C, H, W = shape
        f = self.fmap_size
        if self.mask_ratio[0] == self.mask_ratio[1]:
            len_keep = self.len_keep
        elif self.uniform:
            r = random.uniform(self.mask_ratio[0], self.mask_ratio[1])
            len_keep = round(f * f * (1-r))
        else:
            # Split the batch into three contiguous parts and give each one of
            # the three stored ratios, in random order.
            i1, i2, i3, i4 = np.linspace(0, B, 4, dtype=int).tolist()
            l1, l2, l3 = i2-i1, i3-i2, i4-i3
            r1, r2, r3 = self.ratios[torch.randperm(3, generator=generator)].tolist()
            r = torch.tensor([r1]*l1 + [r2]*l2 + [r3]*l3, device=device).view(-1, 1, 1)
            active = torch.rand(B, f, f, device=device, generator=generator) >= r

            rr, cc = torch.randint(low=0, high=f, size=(2, B), generator=generator).unbind(0)
            active[torch.arange(B), rr, cc] = True  # an extra active site
            return active.unsqueeze_(1)

        # Fixed-count case: keep the first `len_keep` positions of a random ordering.
        idx = torch.rand(B, f*f, generator=generator).argsort(dim=1)
        idx = idx[:, :len_keep].to(device)  # (B, len_keep)
        return torch.zeros(B, f*f, dtype=torch.bool, device=device).scatter_(dim=1, index=idx, value=True).view(B, 1, f, f)

    def forward(self, raw_inp: torch.Tensor, active=None):
        """Mask, encode, fill, decode and compute the reconstruction loss.

        Args:
            raw_inp: input batch, treated as (B, C, H, W).
            active: optional precomputed keep-mask (B, 1, f, f); sampled with
                `mask` when omitted.

        Returns:
            `(active_ex, rec_bchw, spatial_loss)`: the mask upsampled to input
            resolution, the decoder output, and the scalar loss.
        """
        inp_bchw = raw_inp

        # spatial mask
        if active is None:
            active: torch.BoolTensor = self.mask(inp_bchw.shape, inp_bchw.device)  # (B, 1, f, f)
        # The mask is published through a module-level variable of `encoder`.
        encoder._cur_active = active
        active_ex = active.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3)  # (B, 1, H, W)
        masked_bchw = inp_bchw * active_ex

        # get hierarchical encoded sparse features (a list containing four feature maps)
        fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw, pyramid=self.pyramid)
        fea_bcffs.reverse()  # from the smallest feature map to the largest

        cur_active = active
        to_dec = []
        for i, bcff in enumerate(fea_bcffs):  # from the smallest feature map to the largest
            if bcff is not None:
                # fill in empty positions with [mask] embeddings
                bcff = self.en_de_norms[i](bcff)
                mask_tokens = self.mask_tokens[i].expand_as(bcff)
                if self.using_pe:
                    mask_tokens = mask_tokens + self.pos_embeds[i].expand_as(bcff)
                bcff = torch.where(cur_active.expand_as(bcff), bcff, mask_tokens)
                bcff: torch.Tensor = self.en_de_lins[i](bcff)
            # Entries that are None are forwarded unchanged.
            to_dec.append(bcff)
            cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)  # dilate the mask map

        # decode and reconstruct
        rec_bchw = self.dense_decoder(to_dec)

        # calc loss
        mean, var, spatial_loss = self.spatial_loss(raw_inp, rec_bchw, active)
        return active_ex, rec_bchw, spatial_loss

    def spatial_loss(self, inp, rec, active):  # active: (B, 1, f, f)
        """Patch-wise reconstruction loss.

        Returns:
            `(mean, var, loss)`; `mean` / `var` are the normalisation statistics
            used (or None), `var` holding the square root of variance + 1e-6.
        """
        mean = var = None
        if self.pix_norm == 2:
            mean, var = inp.mean(dim=(2, 3), keepdim=True), None
            rec = rec + mean
        inp = self.patchify(inp)
        rec = self.patchify(rec)
        # (B, L=fmap_size**2, N=downsample_raito**2 * C)
        if self.pix_norm == 1:
            mean = inp.mean(dim=-1, keepdim=True)
            var = (inp.var(dim=-1, keepdim=True) + 1e-6)**.5
            inp = (inp - mean) / var
        loss_spa = (rec-inp)**2 if self.loss_l2 else (rec-inp).abs()

        if self.dense_loss:
            return mean, var, loss_spa.mean()  # mean loss on all patches
        else:
            loss_spa = loss_spa.mean(dim=2, keepdim=False)  # (B, L, C) => (B, L)
            non_active = active.logical_not().int().view(active.shape[0], -1)  # (B, 1, f, f) => (B, L)
            return mean, var, loss_spa.mul_(non_active).sum() / (non_active.sum() + 1e-8)  # mean loss on removed patches

    def patchify(self, bchw):
        """Rearrange (B, C, H, W) into (B, f*f, p*p*C) patches, with p = `downsample_raito`."""
        p = self.downsample_raito
        h = w = self.fmap_size
        B, C = bchw.shape[:2]
        bchw = bchw.reshape(shape=(B, C, h, p, w, p))
        bchw = torch.einsum('bchpwq->bhwpqc', bchw)
        bln = bchw.reshape(shape=(B, h*w, p**2 * C))  # (B, f*f, downsample_raito*downsample_raito*3)
        return bln

    def unpatchify(self, bln):
        """Inverse of `patchify`: (B, f*f, p*p*C) back to (B, C, f*p, f*p)."""
        p = self.downsample_raito
        h = w = self.fmap_size
        B, C = bln.shape[0], bln.shape[-1] // p**2
        bln = bln.reshape(shape=(B, h, w, p, p, C))
        bln = torch.einsum('bhwpqc->bchpwq', bln)
        bchw = bln.reshape(shape=(B, C, h * p, w * p))
        return bchw

    def denorm_for_vis(self, x, clamp):
        """Undo the mean/std normalisation using the stored buffers; optionally clamp to [0, 1]."""
        x = x * self.imn_s
        x += self.imn_m
        if clamp:
            x = torch.clamp(x, 0, 1)
        return x

    def __repr__(self):
        return (
            f'\n'
            f'[SparK.config]: {pformat(self.get_config(), indent=2, width=250)}\n'
            f'[SparK.structure]: {super(SparK, self).__repr__().replace(SparK.__name__, "")}\n'
            f'[SparK.dec]: {self.dense_decoder.num_para()}'
        )

    def get_config(self):
        """Return the hyper-parameters that are stored in / checked against checkpoints."""
        return {
            # self
            'mask_ratio': self.mask_ratio[0], 'mask_ratio2': self.mask_ratio[1], 'uniform': self.uniform,
            'using_pe': self.using_pe, 'pix_norm': self.pix_norm,
            'dense_loss': self.dense_loss, 'loss_l2': self.loss_l2,
            'en_de_norm': self.en_de_norm_str, 'en_de_lin': self.en_de_lin_bool, 'sbn': self.sbn, 'pyramid': self.pyramid,

            # enc
            'input_size': self.sparse_encoder.input_size,
            # dec
            'dec_fea_dim': self.dense_decoder.fea_dim, 'double': self.dense_decoder.double_bool, 'heavy': self.dense_decoder.heavy,
        }

    def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False):
        """Regular state dict, plus a 'config' entry when `with_config` is true."""
        state = super(SparK, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        if with_config:
            state['config'] = self.get_config()  # TODO: this seems to cause a DDP broadcast error??
        return state

    def load_state_dict(self, state_dict, strict=True):
        """Load weights and compare the checkpoint's 'config' entry (if any) with this model.

        Note: the 'config' key is removed from the passed-in `state_dict`.

        Raises:
            AttributeError: on a config mismatch when `strict` is true
                (with `strict=False` the mismatch is only printed).
        """
        config = state_dict.pop('config', None)
        incompatible_keys = super(SparK, self).load_state_dict(state_dict, strict=strict)
        if config is not None:
            for k, v in self.get_config().items():
                if config.get(k, None) != v:
                    err = f'[SparseMIM.load_state_dict] config mismatch:  this.{k}={v} (ckpt.{k}={config.get(k, None)})'
                    if strict:
                        raise AttributeError(err)
                    else:
                        print(err)
        return incompatible_keys
def _make_divisible(v, divisor=8, min_value=None): |
""" |
This function is taken from the original tf repo. |
It ensures that all layers have a channel number that is divisible by 8 |
It can be seen here: |
|||| |
""" |
if min_value is None: |
min_value = divisor |
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) |
# Make sure that round down does not go down by more than 10%. |
if new_v < 0.9 * v: |
new_v += divisor |
return new_v |
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Build a 2D sin-cos position embedding for a square grid.

    Args:
        embed_dim: embedding width per position.
        grid_size: int, height and width of the grid.

    Returns:
        pos_embed: array of shape [grid_size*grid_size, embed_dim].
    """
    rows = np.arange(grid_size, dtype=np.float32)
    cols = np.arange(grid_size, dtype=np.float32)
    # here w goes first
    coords = np.stack(np.meshgrid(cols, rows), axis=0)
    coords = coords.reshape([2, 1, grid_size, grid_size])
    return get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode the two planes of `grid` with half of `embed_dim` each and concatenate them.

    Args:
        embed_dim: total embedding width; must be even.
        grid: indexable with 0 and 1, one coordinate plane each.

    Returns:
        Array of shape (H*W, embed_dim).
    """
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    # One half of the channels per grid plane, each (H*W, D/2).
    halves = [get_1d_sincos_pos_embed_from_grid(half, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sin-cos embedding of a set of scalar positions.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: array of positions to be encoded; flattened to size (M,).

    Returns:
        Array of shape (M, embed_dim): the first half of the columns holds
        the sines, the second half the cosines.
    """
    assert embed_dim % 2 == 0
    # `np.float` was an alias of the builtin float (float64) and has been
    # removed from NumPy (AttributeError on 1.24+); use the explicit dtype.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
if __name__ == '__main__':
    # NOTE(review): neither `test_mask` nor `test_align` is defined on the
    # SparK class in this module, so running the file as a script raises
    # AttributeError at the first lookup -- confirm whether these self-tests
    # still exist somewhere.
    for self_test in ('test_mask', 'test_align'):
        getattr(SparK, self_test)()
@ -0,0 +1,172 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import os |
from typing import Any, Callable, Optional, Tuple |
import PIL.Image as PImage |
import torch |
from import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform |
from import transforms_imagenet_eval |
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS |
from torchvision.transforms import transforms |
import dist |
# Pick the bicubic interpolation constant: torchvision's enum when it can be
# imported, Pillow's constant otherwise.
try:
    from torchvision.transforms import InterpolationMode
    interpolation = InterpolationMode.BICUBIC
except ImportError:
    # Only a failed import should trigger the fallback (a bare `except:` would
    # also swallow KeyboardInterrupt and unrelated errors). Import the
    # submodule explicitly: `import PIL` alone does not load `PIL.Image`.
    import PIL.Image
    interpolation = PIL.Image.BICUBIC
def pil_loader(path):
    """Load the image stored at `path` and return it converted to RGB.

    Args:
        path: file-system path of the image.

    Returns:
        A `PIL.Image.Image` in RGB mode.
    """
    # Open the path as a file object to avoid a ResourceWarning; the
    # conversion happens inside the `with` block, while the file is still open.
    # (The open/convert statement was truncated in the source; restored here.)
    with open(path, 'rb') as f:
        img: PImage.Image = PImage.open(f).convert('RGB')
    return img
class ImageNetDataset(DatasetFolder):
    """Folder dataset rooted at `<root>/train` or `<root>/val`, optionally sub-sampled.

    After the folder scan, `samples` / `targets` are replaced by tuples that
    either keep every sample whose class id is below `max_cls_id`, or (when
    `only > 0`) keep a fixed per-class quota chosen from a seeded permutation.
    """

    def __init__(
        self,
        root: str,
        train: bool,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        max_cls_id: int = 1000,
        only=-1,
    ):
        """
        Args:
            root: dataset directory; a trailing path separator, 'train' or
                'val' suffix is stripped before the split name is appended.
            train: selects the 'train' sub-directory, else 'val'.
            transform, target_transform, is_valid_file: forwarded to `DatasetFolder`.
            max_cls_id: class ids >= this value are dropped (in the `only > 0`
                branch the lookup tables only cover ids below it).
            only: per-class sample quota; values <= 0 disable sub-sampling.
        """
        for suffix in (os.path.sep, 'train', 'val'):
            if root.endswith(suffix):
                root = root[:-len(suffix)]
        split_dir = os.path.join(root, 'train' if train else 'val')

        super(ImageNetDataset, self).__init__(
            split_dir,
            # loader=ImageLoader(train),
            loader=pil_loader,
            extensions=IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform, target_transform=target_transform, is_valid_file=is_valid_file
        )

        if only <= 0:
            self.samples = tuple(filter(lambda item: item[-1] < max_cls_id, self.samples))
            self.targets = tuple([s[1] for s in self.samples])
            return

        # Fixed-seed permutation, so the chosen subset is reproducible.
        rng = torch.Generator()
        rng.manual_seed(0)
        order = torch.randperm(len(self.samples), generator=rng).numpy().tolist()

        # Pad the total so it is a multiple of the world size, and add one
        # more world-size worth of samples if the per-rank count would be odd.
        ws = dist.get_world_size()
        remainder = (max_cls_id * only) % ws
        more = 0 if remainder == 0 else (ws - remainder)
        max_total = max_cls_id * only + more
        if (max_total // ws) % 2 == 1:
            more += ws
            max_total += ws

        picked = {c: [] for c in range(max_cls_id)}
        quota = {c: only for c in range(max_cls_id)}
        # The last `more` classes each take one extra sample.
        for c in range(max_cls_id - more, max_cls_id):
            quota[c] += 1

        total = 0
        for i in order:
            path, target = self.samples[i]
            if len(picked[target]) < quota[target]:
                picked[target].append((path, target))
                total += 1
            if total == max_total:
                break

        sp = [item for bucket in picked.values() for item in bucket]
        print(f'[ds] more={more}, len(sp)={len(sp)}')
        self.samples = tuple(sp)
        self.targets = tuple([s[1] for s in self.samples])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return `(sample, target)` for `index`, applying the optional transforms."""
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target
def build_imagenet(mode, data_path, data_set, img_size, eval_crop_pct=None, rrc=0.3, aa='rand-m7-mstd0.5', re_prob=0.0, colorj=0.4):
    """Create the training and validation datasets with their transforms.

    Args:
        mode: case-insensitive; 'pt' and 'le' use a random-resized-crop +
            horizontal-flip pipeline ('pt' with scale lower bound `rrc`), any
            other value uses timm's `create_transform` training pipeline.
        data_path: dataset root; a trailing separator and a 'train'/'val'
            suffix are stripped.
        data_set: only 'imn' is supported.
        img_size: output resolution; at 384 and above the validation images
            are resized directly instead of using timm's eval transform.
        eval_crop_pct: crop percentage forwarded to timm's eval transform.
        rrc: lower bound of the random-resized-crop scale in 'pt' mode.
        aa, re_prob, colorj: auto-augment policy, random-erase probability and
            colour-jitter strength for the `create_transform` pipeline.

    Returns:
        `(dataset_train, dataset_val)`.

    Raises:
        NotImplementedError: if `data_set` is not 'imn'.
    """
    # `mean` / `std` were referenced below without ever being defined
    # (NameError on the first line); use the constants this module imports.
    mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    norm = transforms.Normalize(mean=mean, std=std)

    if img_size >= 384:
        trans_val = transforms.Compose([
            transforms.Resize((img_size, img_size), interpolation=interpolation),
            transforms.ToTensor(),
            norm,
        ])
    else:
        trans_val = transforms_imagenet_eval(
            img_size=img_size, interpolation='bicubic', crop_pct=eval_crop_pct,
            mean=mean, std=std
        )

    mode = mode.lower()
    if mode == 'pt':
        trans_train = transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(rrc, 1.0), interpolation=interpolation),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            norm,
        ])
    elif mode == 'le':
        trans_train = transforms.Compose([
            transforms.RandomResizedCrop(img_size, interpolation=interpolation),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            norm,
        ])
    else:
        trans_train = create_transform(
            is_training=True,
            input_size=img_size,
            auto_augment=aa,
            interpolation='bicubic',
            re_prob=re_prob,
            re_mode='pixel',
            re_count=1,
            color_jitter=colorj,
            mean=mean, std=std,
        )

    if data_path.endswith(os.path.sep):
        data_path = data_path[:-len(os.path.sep)]
    for postfix in ('train', 'val'):
        if data_path.endswith(postfix):
            data_path = data_path[:-len(postfix)]

    if data_set == 'imn':
        dataset_train = ImageNetDataset(root=data_path, transform=trans_train, train=True)
        dataset_val = ImageNetDataset(root=data_path, transform=trans_val, train=False)
    else:
        raise NotImplementedError

    print_transform(trans_train, '[train]')
    print_transform(trans_val, '[val]')

    return dataset_train, dataset_val
def print_transform(transform, s):
    """Print a header, each step in `transform.transforms`, and a separator line.

    Args:
        transform: object exposing an iterable `transforms` attribute.
        s: label inserted into the header line.
    """
    report = [f'Transform {s} = ']
    report.extend(str(step) for step in transform.transforms)
    report.append('---------------------------\n')
    print('\n'.join(report))
@ -0,0 +1,73 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import math |
from pprint import pformat |
def lr_wd_annealing(optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it):
    """Write the scheduled learning rate and weight decay into every param group.

    Learning rate: linear warm-up from 0.5% of `peak_lr` over `wp_it`
    iterations, then half-cosine decay down to 0.1% of `peak_lr` at iteration
    `max_it - 1`. Weight decay: half-cosine from `wd` to `wd_end` over the
    whole run. Each group's optional 'lr_scale' / 'weight_decay_scale'
    multiplies the scheduled value.

    Returns:
        `(min_lr, max_lr, min_wd, max_wd)` over all groups; `min_wd` ignores
        groups whose weight decay is not positive, and -1 stands for
        "no value seen".
    """
    warmup_iters = round(wp_it)
    if cur_it >= warmup_iters:
        progress = (cur_it - warmup_iters) / (max_it-1 - warmup_iters)
        cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * progress))
    else:
        cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / warmup_iters

    progress = cur_it / (max_it-1)
    cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * progress))

    group_lrs, group_wds = [], []
    for group in optimizer.param_groups:
        group['lr'] = cur_lr * group.get('lr_scale', 1)    # 'lr_scale' could be assigned
        group['weight_decay'] = cur_wd * group.get('weight_decay_scale', 1)
        group_lrs.append(group['lr'])
        group_wds.append(group['weight_decay'])

    inf = 1e6   # sentinel for "no minimum found"
    min_lr = min([inf, *group_lrs])
    max_lr = max([-1, *group_lrs])
    min_wd = min([inf, *(w for w in group_wds if w > 0)])
    max_wd = max([-1, *group_wds])

    if min_lr == inf:
        min_lr = -1
    if min_wd == inf:
        min_wd = -1
    return min_lr, max_lr, min_wd, max_wd
def get_param_groups(model, nowd_keys=(), lr_scale=0.0):
    """Split trainable parameters into optimizer groups.

    A parameter goes to a "no_decay" group (weight-decay scale 0) when it is
    one-dimensional, its name ends with '.bias', or its name contains any of
    `nowd_keys`; otherwise to a "decay" group (scale 1). When the model offers
    `get_layer_id_and_scale_exp` and `0 < lr_scale < 1`, groups are further
    split per layer id and receive `lr_scale ** scale_exp` as 'lr_scale'.

    Returns:
        List of dicts with keys 'params', 'weight_decay_scale', 'lr_scale',
        in order of first appearance.
    """
    with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale < 1
    print(f'[get_ft_param_groups][lr decay] with_lr_scale={with_lr_scale}, ft_lr_scale={lr_scale}')

    para_groups, para_groups_dbg = {}, {}   # real groups / printable twin holding names
    for name, para in model.named_parameters():
        if not para.requires_grad:
            continue    # frozen weights

        skip_decay = len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys)
        wd_scale, group_name = (0., 'no_decay') if skip_decay else (1., 'decay')

        if with_lr_scale:
            layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
            group_name = f'layer{layer_id}_' + group_name
            cur_lr_scale = lr_scale ** scale_exp
            dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
        else:
            cur_lr_scale = 1
            dbg = f'[no scale]'

        group = para_groups.setdefault(
            group_name, {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': cur_lr_scale})
        group_dbg = para_groups_dbg.setdefault(
            group_name, {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': dbg})
        group['params'].append(para)
        group_dbg['params'].append(name)

    for g in para_groups_dbg.values():
        g['params'] = pformat(', '.join(g['params']), width=200)

    print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n')
    return list(para_groups.values())
@ -0,0 +1,163 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import json |
import os |
import re |
import sys |
from tap import Tap |
import dist |
# Visual separator: a row of 80 '=' characters framed by newlines.
line_sep = '\n' + '=' * 80 + '\n'
# Typed command-line / runtime configuration for pre-training.
# Fields without a default (`local_rank`, `exp_name`, `data_path`, `exp_dir`)
# have no fallback value; several others are overwritten after parsing.
# NOTE(review): the trailing comments on the attributes are kept byte-for-byte
# (including the ones whose file name is missing in this copy) because the Tap
# base class may read them as help text -- confirm before rewording them.
class Args(Tap):
    # environment
    local_rank: int  # useless
    exp_name: str
    data_path: str
    exp_dir: str
    log_txt_name: str = '/some/path/like/this/log.txt'
    resume: str = ''
    seed: int = 1
    device: str = 'cpu'

    # key MIM hp
    mask: float = 0.6
    mask2: float = -1
    uni: bool = False
    pe: bool = False
    pn: int = 1
    py: int = 4

    # other MIM hp
    den: bool = False
    loss_l2: bool = True
    en_de_norm: str = 'bn'
    en_de_lin: bool = True

    # encoder
    model: str = 'res50'
    model_alias: str = 'res50'
    input_size: int = 224
    sbn: bool = True

    # decoder
    dec_dim: int = 512  # [could be changed in ``]
    double: bool = True
    hea: str = '0_1'
    cmid: int = 0

    # pre-training hyperparameters
    glb_batch_size: int = 0
    batch_size: int = 0  # batch size per GPU
    dp: float = 0.0
    base_lr: float = 2e-4
    lr: float = None
    wd: float = 0.04
    wde: float = 0.2
    ep: int = 1600
    wp_ep: int = 40
    clip: float = 5.
    opt: str = ''
    ada: float = 0.

    # data hyperparameters
    data_set: str = 'imn'
    rrc: float = 0.67
    bs: int = 4096
    num_workers: int = 8

    # would be added during runtime
    cmd: str = ''
    commit_id: str = ''
    commit_msg: str = ''
    last_loss = 1e9  # [would be changed in ``]
    cur_phase: str = ''  # [would be changed in ``]
    cur_ep: str = ''  # [would be changed in ``]
    remain_time: str = ''  # [would be changed in ``]
    finish_time: str = ''  # [would be changed in ``]

    first_logging: bool = True

    @property
    def is_convnext(self):
        # True when the model name mentions ConvNeXt (full name or 'cnx' alias).
        return 'convnext' in self.model or 'cnx' in self.model

    @property
    def is_resnet(self):
        # True when either the model name or its alias contains 'res'.
        return 'res' in self.model or 'res' in self.model_alias

    def __str__(self):
        # Same as the parent's string form, with quoted whitespace right after
        # a "[LE-FT]:" marker removed.
        return re.sub(r"(\[LE-FT\]:\s*)('\s+')?", r'\1', super(Args, self).__str__())

    def log_epoch(self):
        # Append the current progress fields to `log_txt_name` as JSON; only
        # the local master process writes. The very first call truncates the
        # file and writes a one-line JSON header describing the run.
        if not dist.is_local_master():
            return

        if self.first_logging:
            self.first_logging = False
            with open(self.log_txt_name, 'w') as fp:
                json.dump({
                    'name': self.exp_name, 'cmd': self.cmd, 'commit_id': self.commit_id,
                    'model': self.model, 'opt': self.opt,
                }, fp)
                print('', end='\n', file=fp)

        with open(self.log_txt_name, 'a') as fp:
            json.dump({
                'cur': self.cur_phase, 'cur_ep': self.cur_ep,
                'last_L': self.last_loss,
                'rema': self.remain_time, 'fini': self.finish_time,
            }, fp)
# Parse the command line into an `Args` object, initialise the distributed
# environment, then fill in / normalise the fields that are derived at runtime.
# Returns the populated `args` object.
def init_dist_and_get_args(): |
# Imported inside the function rather than at module level
# (NOTE(review): presumably to avoid a circular import -- confirm).
from utils import misc |
from models import model_alias_to_fullname, model_fullname_to_alias |
# initialize |
args = Args(explicit_bool=True).parse_args() |
misc.init_distributed_environ(exp_dir=args.exp_dir) |
# update args |
args.cmd = ' '.join(sys.argv[1:]) |
# Both git queries go through os.popen, so a failing `git` yields empty output
# rather than an exception here.
args.commit_id = os.popen(f'git rev-parse HEAD').read().strip() |
# NOTE(review): `[-1]` raises IndexError when `git log -1` prints nothing
# (e.g. when not run inside a git checkout) -- confirm this is acceptable.
args.commit_msg = os.popen(f'git log -1').read().strip().splitlines()[-1].strip() |
# Accept a short alias for the model; always store the full name in `model`
# and the alias in `model_alias`.
if args.model in model_alias_to_fullname.keys(): |
args.model = model_alias_to_fullname[args.model] |
args.model_alias = model_fullname_to_alias[args.model] |
args.device = dist.get_device() |
# NOTE(review): the left operand of `//` is missing in this copy (presumably the
# requested global batch size) -- restore before running.
args.batch_size = // dist.get_world_size() |
# Recomputed from the per-process value, so it is always a multiple of the world size.
args.glb_batch_size = args.batch_size * dist.get_world_size() |
# Model-family defaults apply only where the user left `opt` / `ada` empty or zero.
if args.is_resnet: |
args.opt = args.opt or 'lamb' |
args.ada = args.ada or 0.95 |
if args.is_convnext: |
args.opt = args.opt or 'lamb' |
args.ada = args.ada or 0.999 |
args.en_de_norm = 'ln' |
args.opt = args.opt.lower() |
# NOTE(review): the assignment target is missing in this copy; the right-hand
# side scales base_lr linearly with the global batch size relative to 256
# (presumably assigned to the learning rate) -- restore before running.
|||| = args.base_lr * args.glb_batch_size / 256 |
# A falsy (zero) `wde` falls back to `wd`.
args.wde = args.wde or args.wd |
# A negative mask2 means "same as mask"; afterwards the pair is ordered so that mask <= mask2.
if args.mask2 < 0: |
args.mask2 = args.mask |
args.mask, args.mask2 = min(args.mask, args.mask2), max(args.mask, args.mask2) |
# NOTE(review): the tested and assigned name is missing in this copy; this
# presumably clamps a count-like argument to at least 1 -- restore before running.
if <= 0: |
|||| = 1 |
# '0_1' -> [0, 1]
args.hea = list(map(int, args.hea.split('_'))) |
args.log_txt_name = os.path.join(args.exp_dir, 'log.txt') |
return args |
@ -0,0 +1,273 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
import datetime |
import functools |
import os |
import subprocess |
import sys |
import time |
from collections import defaultdict, deque |
from typing import Iterator |
import numpy as np |
import pytz |
import torch |
import torch.distributed as tdist |
import dist |
# Run a command string through the shell; extra arguments are forwarded to
# subprocess.run and the resulting CompletedProcess is returned.
os_system = functools.partial(subprocess.run, shell=True)


def os_system_get_stdout(cmd):
    """Run `cmd` through the shell and return its captured stdout decoded as UTF-8.

    stderr is not captured and the exit status is not checked.
    """
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
    """Run `cmd` through the shell and return ``(stdout, stderr)`` decoded as UTF-8.

    The exit status is not checked.
    """
    sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
def is_pow2n(x):
    """Return True when `x` is a positive power of two (1, 2, 4, 8, ...)."""
    if not x > 0:
        return False
    # A power of two has a single set bit, so clearing the lowest set bit leaves zero.
    return (x & (x - 1)) == 0
def time_str():
    """Return the current time in the Asia/Shanghai zone formatted as '[MM-DD HH:MM:SS]'."""
    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
def init_distributed_environ(exp_dir):
    """Initialise the distributed runtime, cuDNN flags and process-level printing.

    After `dist` is initialised and synchronised, cuDNN autotuning is turned on
    (and determinism off), `print` is restricted to the local master, and -- on
    the local master with a non-empty `exp_dir` -- stdout/stderr are additionally
    mirrored to files inside `exp_dir`.
    """
    dist.initialize()
    dist.barrier()

    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True
    cudnn.deterministic = False

    _set_print_only_on_master_proc(is_master=dist.is_local_master())

    if not (dist.is_local_master() and len(exp_dir)):
        return
    stdout_tee = _SyncPrintToFile(exp_dir, stdout=True)
    stderr_tee = _SyncPrintToFile(exp_dir, stdout=False)
    sys.stdout, sys.stderr = stdout_tee, stderr_tee
def save_checkpoint(fname, args, epoch, performance_desc, model_without_ddp_state, optimizer_state):
    """Write a training checkpoint to ``<args.exp_dir>/<fname>`` from the local master.

    The checkpoint stores the stringified args, the model name, the epoch, a
    free-form performance description, and the model / optimizer state dicts
    under the keys that `load_checkpoint` reads back ('epoch',
    'performance_desc', 'module', 'optimizer'). Every process then waits on a
    barrier, so none continues before the file has been written.
    """
    checkpoint_path = os.path.join(args.exp_dir, fname)
    if dist.is_local_master():
        to_save = {
            'args': str(args),
            'arch': args.model,
            'epoch': epoch,
            'performance_desc': performance_desc,
            'module': model_without_ddp_state,
            'optimizer': optimizer_state,
        }
        torch.save(to_save, checkpoint_path)
    dist.barrier()
def load_checkpoint(fname, model_without_ddp, optimizer):
    """Restore model (non-strictly) and optimizer state from the checkpoint file `fname`.

    Returns ``(next_ep, performance_desc)``, where `next_ep` is the stored epoch
    plus one. The optimizer is only restored when the checkpoint has an
    'optimizer' entry.
    """
    print(f'[try to resume from file `{fname}`]')
    ckpt = torch.load(fname, map_location='cpu')

    next_ep = ckpt['epoch'] + 1
    performance_desc = ckpt['performance_desc']

    # strict=False: key mismatches are reported below instead of raising.
    missing, unexpected = model_without_ddp.load_state_dict(ckpt['module'], strict=False)
    print(f'[load_checkpoint] missing_keys={missing}')
    print(f'[load_checkpoint] unexpected_keys={unexpected}')
    print(f'[load_checkpoint] next_ep={next_ep}, performance_desc={performance_desc}')

    if 'optimizer' not in ckpt:
        return next_ep, performance_desc
    optimizer.load_state_dict(ckpt['optimizer'])
    return next_ep, performance_desc
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    `deque` keeps the most recent `window_size` values; `total` and `count`
    accumulate over the whole series.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record `value`, counting it `n` times towards the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum `count` and `total` across all processes.

        Warning: does not synchronize the deque!
        Does nothing when no process group has been initialised.
        """
        if not (tdist.is_available() and tdist.is_initialized()):
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        tdist.barrier()
        tdist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the values currently in the window."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every value seen so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def time_preds(self, counts):
        """Estimate the time for `counts` more steps using the window median as the per-step time.

        Returns ``(remaining_seconds, remaining_time_str, finish_time_str)``;
        the finish time is formatted in local time as 'MM-DD HH:MM'.
        """
        remain_secs = counts * self.median
        remain_time = datetime.timedelta(seconds=round(remain_secs))
        finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
        return remain_secs, str(remain_time), finish_time

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
class MetricLogger(object):
    """Collect named `SmoothedValue` meters and print periodic progress lines.

    Meters are created on first use and are also reachable as attributes
    (``logger.loss`` is ``logger.meters['loss']``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name; `None` values are skipped."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Read `meters` through __dict__
        # so that a lookup made before __init__ has run (e.g. while copying or
        # unpickling) raises AttributeError instead of recursing forever.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronise every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register `meter` under `name`, replacing any existing meter of that name."""
        self.meters[name] = meter

    def log_every(self, max_iters, itrt, print_freq, header=None):
        """Yield the items of `itrt`, timing each step and printing progress.

        `print_freq` is the number of progress lines, spread evenly over
        ``[0, max_iters - 1]``. A plain iterator (one without `preload` or
        `set_epoch` attributes) is advanced exactly `max_iters` times with
        `next`; anything else is iterated until exhausted. A summary line with
        the total time is printed at the end.
        """
        print_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        self.iter_time = SmoothedValue(fmt='{avg:.4f}')
        self.data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(max_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        log_msg = self.delimiter.join(log_msg)

        if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
            # Pull exactly `max_iters` items, lazily, so data time is still measured per step.
            batches = (next(itrt) for _ in range(max_iters))
        else:
            batches = itrt

        for i, obj in enumerate(batches):
            self.data_time.update(time.time() - end)
            yield obj
            self.iter_time.update(time.time() - end)
            if i in print_iters:
                eta_seconds = self.iter_time.global_avg * (max_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(log_msg.format(
                    i, max_iters, eta=eta_string,
                    meters=str(self),
                    time=str(self.iter_time), data=str(self.data_time)))
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps the summary from dividing by zero when max_iters == 0.
        print('{} Total time: {} ({:.3f} s / it)'.format(
            header, total_time_str, total_time / max(max_iters, 1)))
def _set_print_only_on_master_proc(is_master): |
import builtins as __builtin__ |
builtin_print = __builtin__.print |
def prt(msg, *args, **kwargs): |
force = kwargs.pop('force', False) |
clean = kwargs.pop('clean', False) |
deeper = kwargs.pop('deeper', False) |
if is_master or force: |
if not clean: |
f_back = sys._getframe().f_back |
if deeper and f_back.f_back is not None: |
f_back = f_back.f_back |
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] |
msg = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}' |
builtin_print(msg, *args, **kwargs) |
__builtin__.print = prt |
class _SyncPrintToFile(object): |
def __init__(self, exp_dir, stdout=True): |
self.terminal = sys.stdout if stdout else sys.stderr |
fname = os.path.join(exp_dir, 'stdout.txt' if stdout else 'stderr.txt') |
self.log = open(fname, 'w') |
self.log.flush() |
def write(self, message): |
self.terminal.write(message) |
self.log.write(message) |
self.log.flush() |
def flush(self): |
self.terminal.flush() |
self.log.flush() |
@ -0,0 +1,160 @@ |
# Copyright (c) ByteDance, Inc. and its affiliates. |
# All rights reserved. |
# |
# This source code is licensed under the license found in the |
# LICENSE file in the root directory of this source tree. |
# |
# This file is basically a copy of the Lamb optimizer from timm (pytorch-image-models). |
""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb |
This optimizer code was adapted from the following (starting with latest) |
* |
* |
* |
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is |
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. |
In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. |
Original copyrights for above sources are below. |
Modifications Copyright 2021 Ross Wightman |
""" |
import math |
import torch |
from torch.optim.optimizer import Optimizer |
class TimmLAMB(Optimizer):
    """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (bool, optional): whether to apply Adam-style bias correction
            to the moment estimates. (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
        grad_averaging (bool, optional): whether apply (1-beta1) to grad when
            calculating running averages of gradient. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm; ``None``
            disables clipping. The value passed to the constructor is used for all
            parameter groups. (default: 2.0)
        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

    After each `step`, `self.global_grad_norm` holds the (pre-clipping) global
    gradient norm of that step as a Python float.

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
            weight_decay=0.01, grad_averaging=True, max_grad_norm=2.0, trust_clip=False, always_adapt=False):
        defaults = dict(
            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
            trust_clip=trust_clip, always_adapt=always_adapt)
        super().__init__(params, defaults)
        print(f'[lamb1] max_grad_norm={max_grad_norm}')
        self.global_grad_norm = 0

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Gradients are rescaled in place by the global clipping factor before
        they are used.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by `closure`, or None when no closure is given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        device = self.param_groups[0]['params'][0].device
        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly

        # Global L2 norm over the gradients of every parameter in every group.
        global_grad_norm = torch.zeros(1, device=device)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
                global_grad_norm.add_(grad.pow(2).sum())
        global_grad_norm = torch.sqrt(global_grad_norm)
        self.global_grad_norm = global_grad_norm.item()

        # Multiplicative factor (<= 1) that brings the global norm down to max_grad_norm.
        max_grad_norm = self.defaults['max_grad_norm']
        if max_grad_norm is None:
            clip_global_grad_norm = one_tensor
        else:
            max_grad_norm = torch.tensor(max_grad_norm, device=device)
            clip_global_grad_norm = 1 / torch.where(
                global_grad_norm > max_grad_norm,
                global_grad_norm / max_grad_norm,
                one_tensor)

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0
            beta3 = 1 - beta1 if grad_averaging else 1.0

            # assume same step across group now to simplify things
            # per parameter step can be easily support by making it tensor, or pass list into kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            if bias_correction:
                bias_correction1 = 1 - beta1 ** group['step']
                bias_correction2 = 1 - beta2 ** group['step']
            else:
                bias_correction1, bias_correction2 = 1.0, 1.0

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.mul_(clip_global_grad_norm)
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                update = (exp_avg / bias_correction1).div_(denom)

                weight_decay = group['weight_decay']
                if weight_decay != 0:
                    update.add_(p, alpha=weight_decay)

                if weight_decay != 0 or group['always_adapt']:
                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
                    # excluded from weight decay, unless always_adapt == True, then always enabled.
                    w_norm = p.norm(2.0)
                    g_norm = update.norm(2.0)
                    # FIXME nested where required since logical and/or not working in PT XLA
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                        one_tensor,
                    )
                    if group['trust_clip']:
                        # LAMBC trust clipping, upper bound fixed at one
                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
                    update.mul_(trust_ratio)

                p.add_(update, alpha=-group['lr'])

        return loss
